33 changes: 33 additions & 0 deletions ps_fuzz/chat_clients.py
@@ -1,4 +1,7 @@
from .langchain_integration import get_langchain_chat_models_info
+import importlib
+import importlib.util
+import os
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.outputs.llm_result import LLMResult
from langchain.schema import BaseMessage, HumanMessage, SystemMessage, AIMessage
@@ -50,6 +53,36 @@ def interact(self, history: MessageList, messages: MessageList) -> BaseMessage:
        history += [response_message]
        return response_message.content

+# Custom chat client using a user-supplied integration module
+class ClientCustom(ClientBase):
+    """Chat client backed by a user-supplied module exposing initialize_client() and generate()"""
+    def __init__(self, model_name: str):
+        module_filename = f"{model_name}.py"
+        if os.path.isfile(module_filename):
+            # Load the custom integration from a local .py file if one exists ...
+            spec = importlib.util.spec_from_file_location(model_name, module_filename)
+            custom_module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(custom_module)
+            self.custom_module = custom_module
+            logger.info(f"Custom client loaded from file '{module_filename}'")
+        else:
+            # ... otherwise fall back to importing it as an installed module
+            self.custom_module = importlib.import_module(model_name)
+            logger.info(f"Custom client loaded from module '{model_name}'")
+        self.custom_client = self.custom_module.initialize_client()
+
+    def interact(self, history: MessageList, messages: MessageList) -> BaseMessage:
+        history += messages
+
+        # Use the most recent human message as the prompt
+        prompt = ""
+        for msg in reversed(messages):
+            if isinstance(msg, HumanMessage):
+                prompt = msg.content
+                break
+
+        response = self.custom_module.generate(self.custom_client, prompt)
+
+        history.append(AIMessage(content=response))
+        return response

# Chat session allows chatting against target client while maintaining state (history buffer)
class ChatSession:
"Maintains single conversation, including history, and supports an optional system prompts"
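Worth noting for reviewers: ClientCustom establishes an implicit contract. The module named by model_name, loaded from a local `<model_name>.py` if present and imported as an installed module otherwise, must expose `initialize_client()` and `generate(client, prompt)` returning a string. A minimal compatible module might look like this (file name and echo behavior are hypothetical, for illustration only):

# my_provider.py - hypothetical module satisfying the ClientCustom contract;
# ClientCustom("my_provider") would pick this file up from the working directory.

def initialize_client():
    # May return any object that generate() knows how to use.
    return {"prefix": "echo: "}

def generate(client, prompt: str) -> str:
    # Must return the completion as a plain string.
    return client["prefix"] + prompt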
1 change: 1 addition & 0 deletions ps_fuzz/cli.py
@@ -37,6 +37,7 @@ def main():
print("Available providers:")
for provider_name, provider_info in get_langchain_chat_models_info().items():
print(f" {BRIGHT}{provider_name}{RESET}: {provider_info.short_doc}")
print(f" {BRIGHT}custom{RESET}: Custom provider")
sys.exit(0)

    if args.list_attacks:
21 changes: 21 additions & 0 deletions ps_fuzz/custom.py
@@ -0,0 +1,21 @@
+from transformers import pipeline
+import logging
+logger = logging.getLogger(__name__)
+
+# Sample custom LLM integration using a lightweight model
+
+MODEL_NAME = "google/flan-t5-small"
+
+def initialize_client():
+    """Initialize a text2text-generation pipeline for instruction-tuned models."""
+    return pipeline("text2text-generation", model=MODEL_NAME)
+
+
+def generate(client, prompt: str) -> str:
+    """Generate a response using the provided pipeline."""
+    result = client(prompt, max_new_tokens=100)
+    generated_text = result[0]["generated_text"]
+    response = generated_text.strip()
+    logger.debug(f"Prompt: {prompt}")
+    logger.debug(f"Generated text: {response}")
+    return response
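For a quick smoke test of this sample integration (a sketch, assuming `transformers` and a backend such as `torch` are installed; the flan-t5-small weights download on first use):

# Exercise the sample integration end to end.
from ps_fuzz.custom import initialize_client, generate

client = initialize_client()
print(generate(client, "Translate English to German: How old are you?"))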
4 changes: 2 additions & 2 deletions ps_fuzz/interactive_mode.py
@@ -101,7 +101,7 @@ def show(cls, state: AppConfig):
class TargetLLMOptions:
    @classmethod
    def show(cls, state: AppConfig):
-        models_list = get_langchain_chat_models_info().keys()
+        models_list = list(get_langchain_chat_models_info().keys()) + ['custom']
        print("Target LLM Options: Review and modify the target LLM configuration")
        print("------------------------------------------------------------------")
        result = inquirer.prompt([
@@ -124,7 +124,7 @@ def show(cls, state: AppConfig):
class AttackLLMOptions:
    @classmethod
    def show(cls, state: AppConfig):
-        models_list = get_langchain_chat_models_info().keys()
+        models_list = list(get_langchain_chat_models_info().keys()) + ['custom']
        print("Attack LLM Options: Review and modify the service LLM configuration used by the tool to help attack the system prompt")
        print("---------------------------------------------------------------------------------------------------------------------")
        result = inquirer.prompt([
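Both hunks apply the same fix: `dict.keys()` returns a view object, which cannot be concatenated to a list with `+`, so it must be materialized first. For illustration (the provider names here are hypothetical):

models = {"open_ai": None, "anthropic": None}  # hypothetical provider map
# models.keys() + ['custom']                   # would raise TypeError: dict_keys does not support +
list(models.keys()) + ['custom']               # ['open_ai', 'anthropic', 'custom']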
16 changes: 13 additions & 3 deletions ps_fuzz/prompt_injection_fuzzer.py
@@ -155,7 +155,10 @@ def run_interactive_chat(app_config: AppConfig):
    app_config.print_as_table()
    target_system_prompt = app_config.system_prompt
    try:
-        target_client = ClientLangChain(app_config.target_provider, model=app_config.target_model, temperature=0)
+        if app_config.target_provider == "custom":
+            target_client = ClientCustom(app_config.target_model)
+        else:
+            target_client = ClientLangChain(app_config.target_provider, model=app_config.target_model, temperature=0)
        interactive_chat(client=target_client, system_prompts=[target_system_prompt])
    except (ModuleNotFoundError, ValidationError) as e:
        logger.warning(f"Error accessing the Target LLM provider {app_config.target_provider} with model '{app_config.target_model}': {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}")
@@ -167,15 +170,22 @@ def run_fuzzer(app_config: AppConfig):
    custom_benchmark = app_config.custom_benchmark
    target_system_prompt = app_config.system_prompt
    try:
-        target_client = ClientLangChain(app_config.target_provider, model=app_config.target_model, temperature=0)
+        if app_config.target_provider == "custom":
+            target_client = ClientCustom(app_config.target_model)
+        else:
+            target_client = ClientLangChain(app_config.target_provider, model=app_config.target_model, temperature=0)
    except (ModuleNotFoundError, ValidationError) as e:
        logger.warning(f"Error accessing the Target LLM provider {app_config.target_provider} with model '{app_config.target_model}': {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}")
        return
    client_config = ClientConfig(target_client, [target_system_prompt], custom_benchmark=custom_benchmark)

    try:
+        if app_config.attack_provider == "custom":
+            attack_client = ClientCustom(app_config.attack_model)
+        else:
+            attack_client = ClientLangChain(app_config.attack_provider, model=app_config.attack_model, temperature=app_config.attack_temperature)
        attack_config = AttackConfig(
-            attack_client = ClientLangChain(app_config.attack_provider, model=app_config.attack_model, temperature=app_config.attack_temperature),
+            attack_client = attack_client,
            attack_prompts_count = app_config.num_attempts
        )
    except (ModuleNotFoundError, ValidationError) as e:
except (ModuleNotFoundError, ValidationError) as e:
Expand Down