---
library_name: transformers
license: apache-2.0
base_model: google/electra-base-discriminator
tags:
- generated_from_trainer
model-index:
- name: JailBreakModel
  results: []
---

# ELECTRA Trainer for Prompt Injection Detection

Colab notebook: https://colab.research.google.com/drive/11da3m_gYwmkURcjGn8_kp23GiM-INDrm?usp=sharing

## Overview

This repository contains a fine-tuned ELECTRA model for detecting prompt injections in AI systems. The model classifies input prompts into two categories: benign and jailbreak. The goal is to improve the safety and robustness of AI applications.

## Approach and Design Decisions

The primary goal of this project was to build a reliable model that distinguishes safe prompts from potentially harmful ones. Key design decisions included:

- *Model Selection*: I chose ELECTRA for its efficient training process and strong performance on text classification tasks. ELECTRA's architecture allows effective learning from limited data, which matters given how specific this task is.
- *Data Preparation*: A custom dataset was curated, consisting of diverse prompts labeled as either benign or jailbreak. The dataset was built to balance both classes and reduce bias during training.
- *Long Inputs*: Prompts that exceed ELECTRA's maximum input length (512 tokens) are truncated. Although truncation discards part of the prompt, the model still classified such prompts correctly.

## Model Architecture and Training Strategy

The model is based on the `google/electra-base-discriminator` architecture. Here’s an overview of the training strategy:

1. *Tokenization*: I used the ELECTRA tokenizer to prepare input prompts, applying padding and truncation so that all inputs have a uniform length.
2. *Training Configuration*:
   - *Learning Rate*: Set to 5e-05 for stable convergence.
   - *Batch Size*: A batch size of 16 was chosen to balance training speed and memory usage.
   - *Epochs*: The model was trained for 2 epochs to prevent overfitting while still allowing sufficient learning from the dataset.
3. *Evaluation*: The model’s performance was evaluated on a validation set using accuracy, precision, recall, and F1 score.

## Key Results and Observations

- The model achieved high accuracy on the validation set, indicating that it distinguishes effectively between benign and harmful prompts.

## Instructions for Running the Inference Pipeline

To run the inference pipeline for classifying prompts, follow these steps:

1. *Install Dependencies*: Ensure you have Python installed, then install the required libraries with pip (`accelerate` is required by `Trainer`):

```bash
pip install transformers datasets torch accelerate
```

2. *Load the Model*: Load the tokenizer and model from the Hugging Face Hub and wrap the model in a `Trainer`:

```python
# Load the model directly from the Hugging Face Hub
import numpy as np
import torch
from datasets import Dataset
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

tokenizer = AutoTokenizer.from_pretrained("idanpers/JailBreakModel")
model = AutoModelForSequenceClassification.from_pretrained("idanpers/JailBreakModel")

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    report_to="none",  # Disable W&B logging
)

# Create a Trainer instance; only its predict() method is used for inference
trainer = Trainer(
    model=model,
    args=training_args,
)
```

3. *Classify Prompts*: Define a helper that tokenizes a prompt, runs `trainer.predict`, and maps the logits to a label with a confidence score. Then call it on user input:

```python
def classify_prompt(prompt):
    # Reject empty or non-string input
    if not isinstance(prompt, str) or prompt.strip() == "":
        return {"error": "Invalid input. Please provide a non-empty text prompt."}

    # Tokenize the prompt (truncating to the model's maximum length) and
    # wrap it in the Dataset format expected by trainer.predict
    inputs = tokenizer([prompt], truncation=True)
    dataset = Dataset.from_dict(
        {"input_ids": inputs["input_ids"], "attention_mask": inputs["attention_mask"]}
    )

    # Run the model; predictions holds the raw logits with shape (1, num_labels)
    prediction_output = trainer.predict(dataset)

    # Softmax turns the logits into class probabilities used as the confidence score
    probs = torch.softmax(torch.tensor(prediction_output.predictions), dim=1).numpy()
    pred_label = int(np.argmax(probs, axis=1)[0])
    confidence = float(probs[0, pred_label])

    # Map the predicted class index to a label
    label = "PROMPT_INJECTION" if pred_label == 1 else "BENIGN"
    return {"label": label, "confidence": confidence}


# Accept input from the user and classify it
prompt = input("Enter a prompt for classification: ")
result = classify_prompt(prompt)

# Check for errors before accessing the classification result
if "error" in result:
    print(f"Error: {result['error']}")
else:
    print(f"Classification Result: {result['label']}")
    print(f"Confidence Score: {result['confidence']:.2f}")
```
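The evaluation step in the training strategy reports accuracy, precision, recall, and F1, but the code that computes them is not shown. The following is a minimal sketch of such a metric function, not the author's original code. It assumes the `Trainer` convention of a `compute_metrics` callback that receives `(logits, labels)` and returns a dict. It also assumes that class index `1` is the jailbreak/prompt-injection class, as in the inference code above. The metrics are computed in plain NumPy for illustration.

```python
import numpy as np


def compute_metrics(eval_pred):
    """Binary classification metrics for a Trainer-style evaluation loop.

    Assumes class index 1 is the positive (prompt-injection) class and that
    eval_pred unpacks into (logits, labels), as a transformers EvalPrediction does.
    """
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    labels = np.asarray(labels)

    # Confusion-matrix counts for the positive class
    tp = int(np.sum((preds == 1) & (labels == 1)))
    fp = int(np.sum((preds == 1) & (labels == 0)))
    fn = int(np.sum((preds == 0) & (labels == 1)))

    accuracy = float(np.mean(preds == labels))
    # Return 0.0 instead of dividing by zero when a denominator is empty
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return {"accuracy": accuracy, "precision": precision, "recall": recall, "f1": f1}


if __name__ == "__main__":
    # Toy logits for four prompts; argmax gives predictions [1, 0, 1, 0]
    logits = np.array([[0.1, 2.0], [1.5, 0.2], [0.3, 0.9], [2.2, -1.0]])
    labels = np.array([1, 0, 0, 1])
    print(compute_metrics((logits, labels)))
    # → {'accuracy': 0.5, 'precision': 0.5, 'recall': 0.5, 'f1': 0.5}
```

To use it during training, pass `compute_metrics=compute_metrics` to `Trainer` together with an evaluation dataset. The hyperparameters listed above correspond to `TrainingArguments(learning_rate=5e-5, per_device_train_batch_size=16, num_train_epochs=2)`. If scikit-learn is available, `sklearn.metrics.precision_recall_fscore_support` computes the same quantities.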