Spaces:
Running
Running
from typing import Optional | |
import torch | |
import torch.nn.functional as F | |
import weave | |
from transformers import AutoModelForSequenceClassification, AutoTokenizer | |
from ..base import Guardrail | |
class PromptInjectionLlamaGuardrail(Guardrail): | |
""" | |
A guardrail class designed to detect and mitigate prompt injection attacks | |
using a pre-trained language model. This class leverages a sequence | |
classification model to evaluate prompts for potential security threats | |
such as jailbreak attempts and indirect injection attempts. | |
Attributes: | |
model_name (str): The name of the pre-trained model used for sequence | |
classification. | |
max_sequence_length (int): The maximum length of the input sequence | |
for the tokenizer. | |
temperature (float): A scaling factor for the model's logits to | |
control the randomness of predictions. | |
jailbreak_score_threshold (float): The threshold above which a prompt | |
is considered a jailbreak attempt. | |
indirect_injection_score_threshold (float): The threshold above which | |
a prompt is considered an indirect injection attempt. | |
""" | |
model_name: str = "meta-llama/Prompt-Guard-86M" | |
max_sequence_length: int = 512 | |
temperature: float = 1.0 | |
jailbreak_score_threshold: float = 0.5 | |
indirect_injection_score_threshold: float = 0.5 | |
_tokenizer: Optional[AutoTokenizer] = None | |
_model: Optional[AutoModelForSequenceClassification] = None | |
def model_post_init(self, __context): | |
self._tokenizer = AutoTokenizer.from_pretrained(self.model_name) | |
self._model = AutoModelForSequenceClassification.from_pretrained( | |
self.model_name | |
) | |
def get_class_probabilities(self, prompt): | |
inputs = self._tokenizer( | |
prompt, | |
return_tensors="pt", | |
padding=True, | |
truncation=True, | |
max_length=self.max_sequence_length, | |
) | |
with torch.no_grad(): | |
logits = self._model(**inputs).logits | |
scaled_logits = logits / self.temperature | |
probabilities = F.softmax(scaled_logits, dim=-1) | |
return probabilities | |
def get_score(self, prompt: str): | |
probabilities = self.get_class_probabilities(prompt) | |
return { | |
"jailbreak_score": probabilities[0, 2].item(), | |
"indirect_injection_score": ( | |
probabilities[0, 1] + probabilities[0, 2] | |
).item(), | |
} | |
""" | |
Analyzes a given prompt to determine its safety by evaluating the likelihood | |
of it being a jailbreak or indirect injection attempt. | |
This function utilizes the `get_score` method to obtain the probabilities | |
associated with the prompt being a jailbreak or indirect injection attempt. | |
It then compares these probabilities against predefined thresholds to assess | |
the prompt's safety. If the `jailbreak_score` exceeds the `jailbreak_score_threshold`, | |
the prompt is flagged as a potential jailbreak attempt, and a confidence level | |
is calculated and included in the summary. Similarly, if the `indirect_injection_score` | |
surpasses the `indirect_injection_score_threshold`, the prompt is flagged as a potential | |
indirect injection attempt, with its confidence level also included in the summary. | |
Returns a dictionary containing: | |
- "safe": A boolean indicating whether the prompt is considered safe | |
(i.e., both scores are below their respective thresholds). | |
- "summary": A string summarizing the findings, including confidence levels | |
for any detected threats. | |
""" | |
def guard(self, prompt: str): | |
score = self.get_score(prompt) | |
summary = "" | |
if score["jailbreak_score"] > self.jailbreak_score_threshold: | |
confidence = round(score["jailbreak_score"] * 100, 2) | |
summary += f"Prompt is deemed to be a jailbreak attempt with {confidence}% confidence." | |
if score["indirect_injection_score"] > self.indirect_injection_score_threshold: | |
confidence = round(score["indirect_injection_score"] * 100, 2) | |
summary += f" Prompt is deemed to be an indirect injection attempt with {confidence}% confidence." | |
return { | |
"safe": score["jailbreak_score"] < self.jailbreak_score_threshold | |
and score["indirect_injection_score"] | |
< self.indirect_injection_score_threshold, | |
"summary": summary.strip(), | |
} | |
def predict(self, prompt: str): | |
return self.guard(prompt) | |