|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from difflib import SequenceMatcher |
|
from typing import List, Literal, Optional |
|
|
|
from fastapi import HTTPException |
|
|
|
import litellm |
|
from litellm._logging import verbose_proxy_logger |
|
from litellm.caching.caching import DualCache |
|
from litellm.integrations.custom_logger import CustomLogger |
|
from litellm.litellm_core_utils.prompt_templates.factory import ( |
|
prompt_injection_detection_default_pt, |
|
) |
|
from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth |
|
from litellm.router import Router |
|
from litellm.utils import get_formatted_prompt |
|
|
|
|
|
class _OPTIONAL_PromptInjectionDetection(CustomLogger):
    """Proxy hook that rejects requests which look like prompt-injection attacks.

    Two detection strategies are supported, controlled by
    ``prompt_injection_params``:

    - **Heuristic check** (default when no params are given): fuzzy-matches the
      formatted prompt against generated "ignore previous instructions"-style
      phrases using ``difflib.SequenceMatcher``.
    - **LLM API check**: sends the prompt to a guard model (looked up by name in
      the router's ``model_list``) and flags the request when the model's reply
      contains ``llm_api_fail_call_string``.

    A detected attack is surfaced as an HTTP 400, or — when
    ``reject_as_response`` is set — returned as a plain error string so it is
    sent back as the model response instead of an error.
    """

    # Shared 400 payload for every rejection path. Kept byte-identical to the
    # message callers/tests may match on.
    @staticmethod
    def _injection_rejection() -> HTTPException:
        """Build the HTTP 400 raised when a prompt-injection attack is detected."""
        return HTTPException(
            status_code=400,
            detail={"error": "Rejected message. This is a prompt injection attack."},
        )

    def __init__(
        self,
        prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None,
    ):
        # None => run the heuristic check unconditionally in the pre-call hook.
        self.prompt_injection_params = prompt_injection_params
        # Set later via update_environment(); required for the LLM API check.
        self.llm_router: Optional[Router] = None

        # Building blocks for heuristic injection phrases:
        # "<verb> [adjective] instructions [preposition]" variants are generated
        # from the cross product of these lists (see generate_injection_keywords).
        self.verbs = [
            "Ignore",
            "Disregard",
            "Skip",
            "Forget",
            "Neglect",
            "Overlook",
            "Omit",
            "Bypass",
            "Pay no attention to",
            "Do not follow",
            "Do not obey",
        ]
        self.adjectives = [
            "",
            "prior",
            "previous",
            "preceding",
            "above",
            "foregoing",
            "earlier",
            "initial",
        ]
        self.prepositions = [
            "",
            "and start over",
            "and start anew",
            "and begin afresh",
            "and start from scratch",
        ]

    def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
        """Log via the proxy logger; also print when litellm.set_verbose is on."""
        if level == "INFO":
            verbose_proxy_logger.info(print_statement)
        elif level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)

        if litellm.set_verbose is True:
            print(print_statement)  # noqa: T201 - intentional verbose console output

    def update_environment(self, router: Optional[Router] = None):
        """Attach the router and validate the LLM-API-check configuration.

        Raises:
            Exception: if ``llm_api_check`` is enabled but no router is provided,
                or ``llm_api_name`` is not a ``model_name`` in the router's
                ``model_list``.
        """
        self.llm_router = router

        if (
            self.prompt_injection_params is not None
            and self.prompt_injection_params.llm_api_check is True
        ):
            if self.llm_router is None:
                raise Exception(
                    "PromptInjectionDetection: Model List not set. Required for Prompt Injection detection."
                )

            self.print_verbose(
                f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}"
            )
            if (
                self.prompt_injection_params.llm_api_name is None
                or self.prompt_injection_params.llm_api_name
                not in self.llm_router.model_names
            ):
                raise Exception(
                    "PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'."
                )

    def generate_injection_keywords(self) -> List[str]:
        """Return lowercase injection phrases built from verbs/adjectives/prepositions.

        Phrases with two or fewer words are dropped: a lone verb such as
        "ignore" is too common in benign prompts to be a useful signal.
        """
        combinations = []
        for verb in self.verbs:
            for adj in self.adjectives:
                for prep in self.prepositions:
                    # filter(None, ...) drops the empty adjective/preposition slots.
                    phrase = " ".join(filter(None, [verb, adj, prep])).strip()
                    if len(phrase.split()) > 2:
                        combinations.append(phrase.lower())
        return combinations

    def check_user_input_similarity(
        self, user_input: str, similarity_threshold: float = 0.7
    ) -> bool:
        """Fuzzy-match ``user_input`` against generated injection phrases.

        Slides a window of each keyword's length over the lowercased input and
        compares with ``SequenceMatcher``; returns True on the first window
        whose similarity ratio exceeds ``similarity_threshold``.
        """
        user_input_lower = user_input.lower()
        keywords = self.generate_injection_keywords()

        for keyword in keywords:
            keyword_length = len(keyword)
            # Empty range when the input is shorter than the keyword - no match.
            for i in range(len(user_input_lower) - keyword_length + 1):
                substring = user_input_lower[i : i + keyword_length]
                match_ratio = SequenceMatcher(None, substring, keyword).ratio()
                if match_ratio > similarity_threshold:
                    self.print_verbose(
                        print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
                        level="INFO",
                    )
                    return True
        return False

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,
    ):
        """Run the heuristic injection check before the LLM call.

        Returns ``data`` unchanged when the call is allowed (or the call type is
        not one we inspect). Raises HTTP 400 on a detected attack, unless
        ``reject_as_response`` is set, in which case the error string is
        returned so it is delivered as the model response.
        """
        try:
            self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")
            # NOTE: membership test instead of `assert` so the guard survives
            # running under `python -O` (asserts are stripped).
            if call_type not in [
                "completion",
                "text_completion",
                "embeddings",
                "image_generation",
                "moderation",
                "audio_transcription",
            ]:
                self.print_verbose(
                    f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
                )
                return data
            formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)

            is_prompt_attack = False

            if self.prompt_injection_params is not None:
                if self.prompt_injection_params.heuristics_check is True:
                    is_prompt_attack = self.check_user_input_similarity(
                        user_input=formatted_prompt
                    )
                    if is_prompt_attack is True:
                        raise self._injection_rejection()

                if self.prompt_injection_params.vector_db_check is True:
                    pass  # vector-db similarity check not implemented yet
            else:
                # No params configured: heuristic check is the default behavior.
                is_prompt_attack = self.check_user_input_similarity(
                    user_input=formatted_prompt
                )

            if is_prompt_attack is True:
                raise self._injection_rejection()

            return data

        except HTTPException as e:
            if (
                e.status_code == 400
                and isinstance(e.detail, dict)
                and "error" in e.detail
                and self.prompt_injection_params is not None
                and self.prompt_injection_params.reject_as_response
            ):
                # Deliver the rejection as a normal model response body.
                return e.detail.get("error")
            raise e
        except Exception as e:
            verbose_proxy_logger.exception(
                "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format(
                    str(e)
                )
            )
            # Best-effort guard: on an unexpected internal error, let the
            # request proceed unchanged rather than silently returning None.
            return data

    async def async_moderation_hook(
        self,
        data: dict,
        user_api_key_dict: UserAPIKeyAuth,
        call_type: Literal[
            "completion",
            "embeddings",
            "image_generation",
            "moderation",
            "audio_transcription",
        ],
    ) -> Optional[bool]:
        """Run the LLM-API injection check (guard model) on the formatted prompt.

        Returns None when no params are configured, False when the prompt is
        allowed, and raises HTTP 400 when the guard model flags an attack.
        """
        self.print_verbose(
            f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}"
        )

        if self.prompt_injection_params is None:
            return None

        formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)
        is_prompt_attack = False

        # BUGFIX: the field exists on LiteLLMPromptInjectionParams, so a plain
        # getattr(..., default) never falls back when the value is None. Use
        # `or` so a missing/None system prompt gets the library default instead
        # of sending `None` as the guard model's system message.
        prompt_injection_system_prompt = (
            getattr(self.prompt_injection_params, "llm_api_system_prompt", None)
            or prompt_injection_detection_default_pt()
        )

        if (
            self.prompt_injection_params.llm_api_check is True
            and self.prompt_injection_params.llm_api_name is not None
            and self.llm_router is not None
        ):
            # Ask the configured guard model whether the prompt is an attack.
            response = await self.llm_router.acompletion(
                model=self.prompt_injection_params.llm_api_name,
                messages=[
                    {
                        "role": "system",
                        "content": prompt_injection_system_prompt,
                    },
                    {"role": "user", "content": formatted_prompt},
                ],
            )

            self.print_verbose(f"Received LLM Moderation response: {response}")
            self.print_verbose(
                f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}"
            )
            if isinstance(response, litellm.ModelResponse) and isinstance(
                response.choices[0], litellm.Choices
            ):
                # message.content can be None; guard against `in None`.
                response_content = response.choices[0].message.content or ""
                if (
                    self.prompt_injection_params.llm_api_fail_call_string
                    in response_content
                ):
                    is_prompt_attack = True

        if is_prompt_attack is True:
            raise self._injection_rejection()

        return is_prompt_attack
|
|