Skip to content

Commit 3e66b50

Browse files
Merge pull request #2498 from BerriAI/litellm_prompt_injection_detection
feat(prompt_injection_detection.py): support simple heuristic similarity check for prompt injection attacks
2 parents dbc7552 + 234cdbb commit 3e66b50

File tree

3 files changed

+190
-0
lines changed

3 files changed

+190
-0
lines changed
Lines changed: 144 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,144 @@
1+
# +------------------------------------+
2+
#
3+
# Prompt Injection Detection
4+
#
5+
# +------------------------------------+
6+
# Thank you users! We ❤️ you! - Krrish & Ishaan
7+
## Reject a call if it contains a prompt injection attack.
8+
9+
10+
from typing import Optional, Literal
11+
import litellm
12+
from litellm.caching import DualCache
13+
from litellm.proxy._types import UserAPIKeyAuth
14+
from litellm.integrations.custom_logger import CustomLogger
15+
from litellm._logging import verbose_proxy_logger
16+
from litellm.utils import get_formatted_prompt
17+
from fastapi import HTTPException
18+
import json, traceback, re
19+
from difflib import SequenceMatcher
20+
from typing import List
21+
22+
23+
class _ENTERPRISE_PromptInjectionDetection(CustomLogger):
    """Reject a call if it contains a likely prompt-injection attack.

    Heuristic: build every "verb [adjective] [preposition]" phrase from the
    lists below (e.g. "ignore previous instructions and start over") and
    reject user input that contains a substring highly similar to any such
    phrase, as measured by difflib.SequenceMatcher.
    """

    def __init__(self):
        # Imperative verbs an attacker uses to override prior instructions.
        self.verbs = [
            "Ignore",
            "Disregard",
            "Skip",
            "Forget",
            "Neglect",
            "Overlook",
            "Omit",
            "Bypass",
            "Pay no attention to",
            "Do not follow",
            "Do not obey",
        ]
        # Optional qualifiers; "" means the qualifier may be absent.
        self.adjectives = [
            "",
            "prior",
            "previous",
            "preceding",
            "above",
            "foregoing",
            "earlier",
            "initial",
        ]
        # Optional trailing clauses; "" means the clause may be absent.
        self.prepositions = [
            "",
            "and start over",
            "and start anew",
            "and begin afresh",
            "and start from scratch",
        ]
        # Perf fix: the phrase lists are static, so compute the keyword
        # product once here instead of regenerating it on every request
        # inside check_user_input_similarity.
        self.injection_keywords: List[str] = self.generate_injection_keywords()

    def print_verbose(self, print_statement, level: Literal["INFO", "DEBUG"] = "DEBUG"):
        """Log via the proxy logger at the given level; also echo to stdout
        when litellm.set_verbose is enabled."""
        if level == "INFO":
            verbose_proxy_logger.info(print_statement)
        elif level == "DEBUG":
            verbose_proxy_logger.debug(print_statement)

        if litellm.set_verbose is True:
            print(print_statement)  # noqa

    def generate_injection_keywords(self) -> List[str]:
        """Return every lower-cased verb/adjective/preposition combination.

        Empty adjective/preposition entries are filtered out, so shorter
        phrases like "ignore previous" (no trailing clause) are produced too.
        """
        combinations = []
        for verb in self.verbs:
            for adj in self.adjectives:
                for prep in self.prepositions:
                    phrase = " ".join(filter(None, [verb, adj, prep])).strip()
                    combinations.append(phrase.lower())
        return combinations

    def check_user_input_similarity(
        self, user_input: str, similarity_threshold: float = 0.7
    ) -> bool:
        """Return True if any same-length substring of user_input exceeds
        similarity_threshold similarity to an injection keyword.

        Slides a window of each keyword's length across the lower-cased
        input; O(len(keywords) * len(user_input)) SequenceMatcher calls.
        """
        user_input_lower = user_input.lower()
        # Use the list precomputed in __init__; fall back to generating it
        # so instances built without __init__ (e.g. __new__) still work.
        keywords = getattr(self, "injection_keywords", None) or self.generate_injection_keywords()

        for keyword in keywords:
            # Compare substrings of exactly the keyword's length.
            keyword_length = len(keyword)

            for i in range(len(user_input_lower) - keyword_length + 1):
                substring = user_input_lower[i : i + keyword_length]
                match_ratio = SequenceMatcher(None, substring, keyword).ratio()
                if match_ratio > similarity_threshold:
                    self.print_verbose(
                        print_statement=f"Rejected user input - {user_input}. {match_ratio} similar to {keyword}",
                        level="INFO",
                    )
                    return True  # Found a highly similar substring
        return False  # No substring crossed the threshold

    async def async_pre_call_hook(
        self,
        user_api_key_dict: UserAPIKeyAuth,
        cache: DualCache,
        data: dict,
        call_type: str,  # "completion", "embeddings", "image_generation", "moderation"
    ):
        """Proxy pre-call hook.

        Raises HTTPException(400) when the formatted prompt looks like a
        prompt-injection attack; otherwise returns ``data`` unchanged.
        Unknown call types and internal errors are best-effort: the request
        is logged and allowed through.
        """
        try:
            self.print_verbose("Inside Prompt Injection Detection Pre-Call Hook")
            # Bug fix: this check was an `assert`, which is stripped under
            # `python -O`; an explicit membership test always runs.
            if call_type not in [
                "completion",
                "embeddings",
                "image_generation",
                "moderation",
                "audio_transcription",
            ]:
                self.print_verbose(
                    f"Call Type - {call_type}, not in accepted list - ['completion','embeddings','image_generation','moderation','audio_transcription']"
                )
                return data
            formatted_prompt = get_formatted_prompt(data=data, call_type=call_type)  # type: ignore

            if self.check_user_input_similarity(user_input=formatted_prompt):
                raise HTTPException(
                    status_code=400,
                    detail={
                        "error": "Rejected message. This is a prompt injection attack."
                    },
                )

            return data

        except HTTPException as e:
            raise e
        except Exception as e:
            # Bug fix: previously this printed the traceback and fell
            # through, implicitly returning None and dropping the request
            # payload downstream. Log and let the request proceed — this
            # guardrail is best-effort, not a hard gate.
            traceback.print_exc()
            return data

litellm/proxy/proxy_server.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1665,6 +1665,18 @@ async def load_config(
16651665

16661666
banned_keywords_obj = _ENTERPRISE_BannedKeywords()
16671667
imported_list.append(banned_keywords_obj)
1668+
elif (
1669+
isinstance(callback, str)
1670+
and callback == "detect_prompt_injection"
1671+
):
1672+
from litellm.proxy.enterprise.enterprise_hooks.prompt_injection_detection import (
1673+
_ENTERPRISE_PromptInjectionDetection,
1674+
)
1675+
1676+
prompt_injection_detection_obj = (
1677+
_ENTERPRISE_PromptInjectionDetection()
1678+
)
1679+
imported_list.append(prompt_injection_detection_obj)
16681680
else:
16691681
imported_list.append(
16701682
get_instance_fn(

litellm/utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5301,6 +5301,40 @@ def get_supported_openai_params(model: str, custom_llm_provider: str):
53015301
]
53025302

53035303

5304+
def get_formatted_prompt(
    data: dict,
    call_type: Literal[
        "completion",
        "embedding",
        "embeddings",
        "image_generation",
        "audio_transcription",
        "moderation",
    ],
) -> str:
    """
    Extracts the prompt from the input data based on the call type.

    Only plain-string content is collected: non-string entries (e.g. token
    lists for embeddings, structured content blocks in messages) are
    skipped, and missing keys yield "" instead of raising KeyError.
    Both "embedding" and "embeddings" are accepted, since proxy hooks pass
    the plural form.

    Returns a string.
    """
    prompt = ""
    if call_type == "completion":
        # Guarded .get: a malformed request without "messages" should not
        # raise KeyError here.
        for m in data.get("messages", []):
            if "content" in m and isinstance(m["content"], str):
                prompt += m["content"]
    elif call_type in ("embedding", "embeddings", "moderation"):
        _input = data.get("input", "")
        if isinstance(_input, str):
            prompt = _input
        elif isinstance(_input, list):
            for m in _input:
                # Embedding inputs may be pre-tokenized (lists of ints);
                # only string items can be concatenated into a prompt.
                if isinstance(m, str):
                    prompt += m
    elif call_type == "image_generation":
        # Consistent with the audio_transcription branch: missing "prompt"
        # yields "" rather than KeyError.
        prompt = data.get("prompt", "")
    elif call_type == "audio_transcription":
        if "prompt" in data:
            prompt = data["prompt"]
    return prompt
5336+
5337+
53045338
def get_llm_provider(
53055339
model: str,
53065340
custom_llm_provider: Optional[str] = None,

0 commit comments

Comments
 (0)