# +------------------------------------+ # # Prompt Injection Detection # # +------------------------------------+ # Thank you users! We ❤️ you! - Krrish & Ishaan ## Reject a call if it contains a prompt injection attack. from difflib import SequenceMatcher from typing import List, Literal, Optional from fastapi import HTTPException import litellm from litellm._logging import verbose_proxy_logger from litellm.caching.caching import DualCache from litellm.integrations.custom_logger import CustomLogger from litellm.litellm_core_utils.prompt_templates.factory import ( prompt_injection_detection_default_pt, ) from litellm.proxy._types import LiteLLMPromptInjectionParams, UserAPIKeyAuth from litellm.router import Router from litellm.utils import get_formatted_prompt class _OPTIONAL_PromptInjectionDetection(CustomLogger): # Class variables or attributes def __init__( self, prompt_injection_params: Optional[LiteLLMPromptInjectionParams] = None, ): self.prompt_injection_params = prompt_injection_params self.llm_router: Optional[Router] = None self.verbs = [ "Ignore", "Disregard", "Skip", "Forget", "Neglect", "Overlook", "Omit", "Bypass", "Pay no attention to", "Do not follow", "Do not obey", ] self.adjectives = [ "", "prior", "previous", "preceding", "above", "foregoing", "earlier", "initial", ] self.prepositions = [ "", "and start over", "and start anew", "and begin afresh", "and start from scratch", ] def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"): if level == "INFO": verbose_proxy_logger.info(print_statement) elif level == "DEBUG": verbose_proxy_logger.debug(print_statement) if litellm.set_verbose is True: print(print_statement) # noqa def update_environment(self, router: Optional[Router] = None): self.llm_router = router if ( self.prompt_injection_params is not None and self.prompt_injection_params.llm_api_check is True ): if self.llm_router is None: raise Exception( "PromptInjectionDetection: Model List not set. Required for Prompt Injection detection." ) self.print_verbose( f"model_names: {self.llm_router.model_names}; self.prompt_injection_params.llm_api_name: {self.prompt_injection_params.llm_api_name}" ) if ( self.prompt_injection_params.llm_api_name is None or self.prompt_injection_params.llm_api_name not in self.llm_router.model_names ): raise Exception( "PromptInjectionDetection: Invalid LLM API Name. LLM API Name must be a 'model_name' in 'model_list'." ) def generate_injection_keywords(self) -> List[str]: combinations = [] for verb in self.verbs: for adj in self.adjectives: for prep in self.prepositions: phrase = " ".join(filter(None, [verb, adj, prep])).strip() if ( len(phrase.split()) > 2 ): # additional check to ensure more than 2 words combinations.append(phrase.lower()) return combinations def check_user_input_similarity( self, user_input: str, similarity_threshold: float = 0.7 ) -> bool: user_input_lower = user_input.lower() keywords = self.generate_injection_keywords() for keyword in keywords: # Calculate the length of the keyword to extract substrings of the same length from user input keyword_length = len(keyword) for i in range(len(user_input_lower) - keyword_length + 1): # Extract a substring of the same length as the keyword substring = user_input_lower[i : i + keyword_length] # Calculate similarity match_ratio = SequenceMatcher(None, substring, keyword).ratio() if match_ratio > similarity_threshold: self.print_verbose( print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}", level="INFO", ) return True # Found a highly similar substring return False # No substring crossed the threshold async def async_pre_call_hook( self, user_api_key_dict: UserAPIKeyAuth, cache: DualCache, data: dict, call_type: str, # "completion", "embeddings", "image_generation", "moderation" ): try: """ - check if user id part of call - check if user id part of blocked list """ self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook") try: assert call_type in [ "completion", "text_completion", "embeddings", "image_generation", "moderation", "audio_transcription", ] except Exception: self.print_verbose( f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']" ) return data formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore is_prompt_attack = False if self.prompt_injection_params is not None: # 1. check if heuristics check turned on if self.prompt_injection_params.heuristics_check is True: is_prompt_attack = self.check_user_input_similarity( user_input=formatted_prompt ) if is_prompt_attack is True: raise HTTPException( status_code=400, detail={ "error": "Rejected message. This is a prompt injection attack." }, ) # 2. check if vector db similarity check turned on [TODO] Not Implemented yet if self.prompt_injection_params.vector_db_check is True: pass else: is_prompt_attack = self.check_user_input_similarity( user_input=formatted_prompt ) if is_prompt_attack is True: raise HTTPException( status_code=400, detail={ "error": "Rejected message. This is a prompt injection attack." }, ) return data except HTTPException as e: if ( e.status_code == 400 and isinstance(e.detail, dict) and "error" in e.detail # type: ignore and self.prompt_injection_params is not None and self.prompt_injection_params.reject_as_response ): return e.detail.get("error") raise e except Exception as e: verbose_proxy_logger.exception( "litellm.proxy.hooks.prompt_injection_detection.py::async_pre_call_hook(): Exception occured - {}".format( str(e) ) ) async def async_moderation_hook( # type: ignore self, data: dict, user_api_key_dict: UserAPIKeyAuth, call_type: Literal[ "completion", "embeddings", "image_generation", "moderation", "audio_transcription", ], ) -> Optional[bool]: self.print_verbose( f"IN ASYNC MODERATION HOOK - self.prompt_injection_params = {self.prompt_injection_params}" ) if self.prompt_injection_params is None: return None formatted_prompt = get_formatted_prompt(data=data, call_type=call_type) # type: ignore is_prompt_attack = False prompt_injection_system_prompt = getattr( self.prompt_injection_params, "llm_api_system_prompt", prompt_injection_detection_default_pt(), ) # 3. check if llm api check turned on if ( self.prompt_injection_params.llm_api_check is True and self.prompt_injection_params.llm_api_name is not None and self.llm_router is not None ): # make a call to the llm api response = await self.llm_router.acompletion( model=self.prompt_injection_params.llm_api_name, messages=[ { "role": "system", "content": prompt_injection_system_prompt, }, {"role": "user", "content": formatted_prompt}, ], ) self.print_verbose(f"Received LLM Moderation response: {response}") self.print_verbose( f"llm_api_fail_call_string: {self.prompt_injection_params.llm_api_fail_call_string}" ) if isinstance(response, litellm.ModelResponse) and isinstance( response.choices[0], litellm.Choices ): if self.prompt_injection_params.llm_api_fail_call_string in response.choices[0].message.content: # type: ignore is_prompt_attack = True if is_prompt_attack is True: raise HTTPException( status_code=400, detail={ "error": "Rejected message. This is a prompt injection attack." }, ) return is_prompt_attack