diff --git a/ps_fuzz/chat_clients.py b/ps_fuzz/chat_clients.py
index 533361b..04e3df5 100644
--- a/ps_fuzz/chat_clients.py
+++ b/ps_fuzz/chat_clients.py
@@ -1,4 +1,7 @@
 from .langchain_integration import get_langchain_chat_models_info
+import importlib
+import importlib.util
+import os
 from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.outputs.llm_result import LLMResult
 from langchain.schema import BaseMessage, HumanMessage, SystemMessage, AIMessage
@@ -50,6 +53,36 @@ def interact(self, history: MessageList, messages: MessageList) -> BaseMessage:
         history += [response_message]
         return response_message.content
 
+# Custom chat client that delegates to a user-supplied Python module
+class ClientCustom(ClientBase):
+    """Chat client wrapper around a module exposing initialize_client() and generate()"""
+    def __init__(self, model_name: str):
+        module_filename = f"{model_name}.py"
+        if os.path.isfile(module_filename):
+            spec = importlib.util.spec_from_file_location(model_name, module_filename)
+            custom_module = importlib.util.module_from_spec(spec)
+            spec.loader.exec_module(custom_module)
+            self.custom_module = custom_module
+            logger.info(f"Custom client loaded from file '{module_filename}'")
+        else:
+            self.custom_module = importlib.import_module(model_name)
+            logger.info(f"Custom client loaded from module '{model_name}'")
+        self.custom_client = self.custom_module.initialize_client()
+
+    def interact(self, history: MessageList, messages: MessageList) -> BaseMessage:
+        history += messages
+        # Use the most recent human message as the prompt
+        prompt = ""
+        for msg in reversed(messages):
+            if isinstance(msg, HumanMessage):
+                prompt = msg.content
+                break
+
+        response = self.custom_module.generate(self.custom_client, prompt)
+
+        history.append(AIMessage(content=response))
+        return response
+
 # Chat session allows chatting against target client while maintaining state (history buffer)
 class ChatSession:
     "Maintains single conversation, including history, and supports an optional system prompts"
diff --git a/ps_fuzz/cli.py b/ps_fuzz/cli.py
index 217ce4a..c8e17fb 100644
--- a/ps_fuzz/cli.py
+++ b/ps_fuzz/cli.py
@@ -37,6 +37,7 @@ def main():
         print("Available providers:")
         for provider_name, provider_info in get_langchain_chat_models_info().items():
             print(f" {BRIGHT}{provider_name}{RESET}: {provider_info.short_doc}")
+        print(f" {BRIGHT}custom{RESET}: Custom provider (a Python module exposing initialize_client() and generate())")
         sys.exit(0)
 
     if args.list_attacks:
diff --git a/ps_fuzz/custom.py b/ps_fuzz/custom.py
new file mode 100644
index 0000000..bd7f30f
--- /dev/null
+++ b/ps_fuzz/custom.py
@@ -0,0 +1,21 @@
+from transformers import pipeline
+import logging
+logger = logging.getLogger(__name__)
+
+# Sample custom LLM integration using a lightweight model
+
+MODEL_NAME = "google/flan-t5-small"
+
+def initialize_client():
+    """Initialize a text2text-generation pipeline for instruction-tuned models."""
+    return pipeline("text2text-generation", model=MODEL_NAME)
+
+
+def generate(client, prompt: str):
+    """Generate a response using the provided pipeline."""
+    result = client(prompt, max_new_tokens=100)
+    generated_text = result[0]["generated_text"]
+    response = generated_text.strip()
+    logger.debug(f"Prompt: {prompt}")
+    logger.debug(f"Generated text: {response}")
+    return response
diff --git a/ps_fuzz/interactive_mode.py b/ps_fuzz/interactive_mode.py
index 0d37f5e..533a1b0 100644
--- a/ps_fuzz/interactive_mode.py
+++ b/ps_fuzz/interactive_mode.py
@@ -101,7 +101,7 @@ def show(cls, state: AppConfig):
 class TargetLLMOptions:
     @classmethod
     def show(cls, state: AppConfig):
-        models_list = get_langchain_chat_models_info().keys()
+        models_list = list(get_langchain_chat_models_info().keys()) + ['custom']
         print("Target LLM Options: Review and modify the target LLM configuration")
         print("------------------------------------------------------------------")
         result = inquirer.prompt([
@@ -124,7 +124,7 @@ def show(cls, state: AppConfig):
 class AttackLLMOptions:
     @classmethod
     def show(cls, state: AppConfig):
-        models_list = get_langchain_chat_models_info().keys()
+        models_list = list(get_langchain_chat_models_info().keys()) + ['custom']
         print("Attack LLM Options: Review and modify the service LLM configuration used by the tool to help attack the system prompt")
         print("---------------------------------------------------------------------------------------------------------------------")
         result = inquirer.prompt([
diff --git a/ps_fuzz/prompt_injection_fuzzer.py b/ps_fuzz/prompt_injection_fuzzer.py
index 6e858fa..d5ca003 100644
--- a/ps_fuzz/prompt_injection_fuzzer.py
+++ b/ps_fuzz/prompt_injection_fuzzer.py
@@ -155,7 +155,10 @@ def run_interactive_chat(app_config: AppConfig):
     app_config.print_as_table()
     target_system_prompt = app_config.system_prompt
     try:
-        target_client = ClientLangChain(app_config.target_provider, model=app_config.target_model, temperature=0)
+        if app_config.target_provider == "custom":
+            target_client = ClientCustom(app_config.target_model)
+        else:
+            target_client = ClientLangChain(app_config.target_provider, model=app_config.target_model, temperature=0)
         interactive_chat(client=target_client, system_prompts=[target_system_prompt])
     except (ModuleNotFoundError, ValidationError) as e:
         logger.warning(f"Error accessing the Target LLM provider {app_config.target_provider} with model '{app_config.target_model}': {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}")
@@ -167,15 +170,22 @@ def run_fuzzer(app_config: AppConfig):
     custom_benchmark = app_config.custom_benchmark
     target_system_prompt = app_config.system_prompt
     try:
-        target_client = ClientLangChain(app_config.target_provider, model=app_config.target_model, temperature=0)
+        if app_config.target_provider == "custom":
+            target_client = ClientCustom(app_config.target_model)
+        else:
+            target_client = ClientLangChain(app_config.target_provider, model=app_config.target_model, temperature=0)
     except (ModuleNotFoundError, ValidationError) as e:
         logger.warning(f"Error accessing the Target LLM provider {app_config.target_provider} with model '{app_config.target_model}': {colorama.Fore.RED}{e}{colorama.Style.RESET_ALL}")
         return
     client_config = ClientConfig(target_client, [target_system_prompt], custom_benchmark=custom_benchmark)
 
     try:
+        if app_config.attack_provider == "custom":
+            attack_client = ClientCustom(app_config.attack_model)
+        else:
+            attack_client = ClientLangChain(app_config.attack_provider, model=app_config.attack_model, temperature=app_config.attack_temperature)
         attack_config = AttackConfig(
-            attack_client = ClientLangChain(app_config.attack_provider, model=app_config.attack_model, temperature=app_config.attack_temperature),
+            attack_client = attack_client,
             attack_prompts_count = app_config.num_attempts
         )
     except (ModuleNotFoundError, ValidationError) as e:
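Note for reviewers: `ps_fuzz/custom.py` is only the bundled reference integration (flan-t5-small). `ClientCustom` will load any module that exposes the same two-function contract, `initialize_client()` and `generate(client, prompt)`. Below is a minimal sketch of such a module; the file name `my_provider.py` and the echo logic are hypothetical, chosen only to illustrate the contract:

```python
# my_provider.py -- hypothetical custom provider module.
# With the provider set to "custom" and the model set to "my_provider",
# ClientCustom finds this file in the working directory and loads it via importlib.

def initialize_client():
    # Return any object that generate() knows how to use; a real integration
    # would construct an SDK client or load local model weights here.
    return {"suffix": " (echoed)"}

def generate(client, prompt: str) -> str:
    # Return the model's text response for a single prompt string.
    return prompt + client["suffix"]
```

Because `ClientCustom.__init__` checks `os.path.isfile(f"{model_name}.py")` before falling back to `importlib.import_module(model_name)`, the "model" field accepts either a file in the current directory (`my_provider`) or an importable dotted module path such as `ps_fuzz.custom`.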