In this post, we will discuss how to build a prompt injection detector by framing it as a simple classification task with Scikit-learn’s Logistic Regression. Logistic Regression is a statistical method for binary classification problems: it models the probability that an input belongs to one of two classes. That fits our problem, where each prompt is either benign or a prompt injection.
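To make that definition concrete, here is a small sketch on made-up one-dimensional data (not the SPML dataset). It shows that Scikit-learn’s LogisticRegression computes a linear score z = w·x + b, turns it into a probability with the sigmoid function 1 / (1 + e^(-z)), and predicts class 1 when that probability is above 0.5. The toy values and the `sigmoid` helper are illustrative assumptions.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression

# Made-up 1-D data: small values are class 0, large values are class 1
X = np.array([[0.5], [1.0], [1.5], [2.0], [3.0], [3.5], [4.0], [4.5]])
y = np.array([0, 0, 0, 0, 1, 1, 1, 1])

clf = LogisticRegression().fit(X, y)


def sigmoid(z):
    """Logistic function: squashes any real number into the range (0, 1)."""
    return 1.0 / (1.0 + np.exp(-z))


# Linear score z = w * x + b from the fitted coefficient and intercept
z = X @ clf.coef_.ravel() + clf.intercept_[0]
manual_proba = sigmoid(z)

# Scikit-learn's probability for class 1 is the sigmoid of the linear score
print(np.allclose(manual_proba, clf.predict_proba(X)[:, 1]))  # True
```

The same mechanism is applied later in the post, except that each input is a text embedding vector instead of a single number.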
We will use the SPML Chatbot Prompt Injection Dataset as our input data.
Install the following libraries:
pip install datasets
pip install sentence-transformers
pip install scikit-learn
We will start by loading the dataset:
from datasets import load_dataset
dataset = load_dataset("reshabhs/SPML_Chatbot_Prompt_Injection")
Let’s look at the dataset:
dataset
DatasetDict({
    train: Dataset({
        features: ['System Prompt', 'User Prompt', 'Prompt injection', 'Degree', 'Source'],
        num_rows: 16012
    })
})
This displays the dataset structure. There are 16,012 records in this dataset, each with five columns:
- System Prompt
- User Prompt
- Prompt injection
- Degree
- Source
For feature extraction, we will use text embeddings of the System Prompt and User Prompt columns. We will create a new column named query by concatenating the System Prompt and User Prompt columns (the later code reads this column as `query`, so the name must match).
def build_user_query(row):
    # Join the system and user prompts into one text field; adjust the separator as needed
    row['query'] = str(row['System Prompt']) + '\n' + str(row['User Prompt'])
    return row
dataset = dataset.map(build_user_query)
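One detail worth knowing about `build_user_query`: `str(None)` is the string `'None'`, so if either prompt were missing, the literal word None would be embedded. I have not checked whether the SPML data actually contains missing prompts; the variant below is a hedged sketch that treats missing values as empty strings and takes the separator as a parameter (the `sep` argument is my addition, not part of the original code).

```python
def build_user_query(row, sep="\n"):
    """Concatenate system and user prompts into a single 'query' field.

    Missing (None) prompts become empty strings instead of the text 'None'.
    """
    system_prompt = row.get("System Prompt") or ""
    user_prompt = row.get("User Prompt") or ""
    row["query"] = f"{system_prompt}{sep}{user_prompt}"
    return row


# Toy row for illustration (not taken from the dataset)
example = build_user_query({"System Prompt": "You are a helpful banking assistant.",
                            "User Prompt": None})
print(repr(example["query"]))  # 'You are a helpful banking assistant.\n'
```

It is a drop-in replacement for `dataset.map(build_user_query)`; a different separator can be passed with `dataset.map(build_user_query, fn_kwargs={"sep": " "})`.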
Now, we will split the dataset into training and test sets: 80% of the records go to the training set and 20% to the test set.
split = dataset['train'].train_test_split(test_size=0.2)
training_set = split['train']
test_set = split['test']
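The split sizes shown below can be checked by hand. Assuming `train_test_split` rounds the test fraction up (as Scikit-learn’s function of the same name does), 20% of 16,012 is 3,202.4, which becomes 3,203 test rows and leaves 12,809 for training. The split is also shuffled without a fixed seed, so each run produces a different split; the commented call shows the `seed` argument that makes it reproducible.

```python
import math

n_samples = 16012   # rows in the SPML train split
test_size = 0.2

# Test fraction rounded up; the training set gets the remainder
n_test = math.ceil(test_size * n_samples)
n_train = n_samples - n_test
print(n_train, n_test)  # 12809 3203

# Reproducible variant of the split (not run here):
# split = dataset["train"].train_test_split(test_size=0.2, seed=42)
```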
We can look at each of the training and test sets.
The training_set has 12,809 records.
Dataset({
    features: ['System Prompt', 'User Prompt', 'Prompt injection', 'Degree', 'Source', 'query'],
    num_rows: 12809
})
The test_set has 3,203 records.
Dataset({
    features: ['System Prompt', 'User Prompt', 'Prompt injection', 'Degree', 'Source', 'query'],
    num_rows: 3203
})
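Before training, it helps to check how the labels are distributed in each split, since class imbalance affects how the metrics should be read later. A minimal helper, assuming the `Prompt injection` column holds 0 (no injection) and 1 (injection); `label_distribution` is a hypothetical name:

```python
from collections import Counter


def label_distribution(labels):
    """Return {label: (count, fraction)} for an iterable of class labels."""
    counts = Counter(labels)
    total = sum(counts.values())
    return {label: (n, n / total) for label, n in sorted(counts.items())}


# Toy labels for illustration; with the real data you would call
# label_distribution(training_set["Prompt injection"]) and
# label_distribution(test_set["Prompt injection"])
print(label_distribution([0, 1, 1, 1]))  # {0: (1, 0.25), 1: (3, 0.75)}
```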
Now, we will embed the training and test datasets.
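We will use the sentence-transformers/all-mpnet-base-v2 model for text embedding. It maps each input text to a 768-dimensional vector and, while not the smallest sentence-transformers model, it is small enough to run on a CPU. By default it truncates inputs longer than 384 word pieces, which matters here because a long system prompt plus a user prompt can exceed that limit. Text embedding converts text into numerical vectors, which lets us process text with standard machine learning models.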
from sentence_transformers import SentenceTransformer
# Load model
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")
# Convert text to embeddings
train_embeddings = model.encode(training_set["query"], show_progress_bar=True)
test_embeddings = model.encode(test_set["query"], show_progress_bar=True)
On my machine it took around five minutes to generate embeddings of both training and test datasets.
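Since embedding takes several minutes, it can be worth caching the vectors on disk so that re-running the notebook skips that step. A minimal sketch using NumPy’s `.npy` format; the `cached_embeddings` helper and the file names are hypothetical, and the demo uses a stand-in function instead of the real model:

```python
import os
import tempfile

import numpy as np


def cached_embeddings(path, compute_fn):
    """Load embeddings from `path` if present; otherwise compute, save, and return them.

    `path` should end in '.npy', because np.save appends that suffix otherwise.
    """
    if os.path.exists(path):
        return np.load(path)
    embeddings = np.asarray(compute_fn())
    np.save(path, embeddings)
    return embeddings


# Real usage (not run here):
# train_embeddings = cached_embeddings(
#     "train_embeddings.npy",
#     lambda: model.encode(training_set["query"], show_progress_bar=True),
# )

# Demo with a stand-in compute function that records how often it is called
calls = []


def fake_encode():
    calls.append(1)
    return [[0.1, 0.2], [0.3, 0.4]]


with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "train_embeddings.npy")
    first = cached_embeddings(path, fake_encode)   # computes and saves
    second = cached_embeddings(path, fake_encode)  # loads from disk
print(len(calls), first.shape)  # 1 (2, 2)
```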
We train the Logistic Regression model on the training set embeddings, using the Prompt injection column as the label, to classify prompts as either “Good Queries” or “Infected Queries”.
from sklearn.linear_model import LogisticRegression
clf = LogisticRegression(random_state=42)
clf.fit(train_embeddings, training_set["Prompt injection"])
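`clf.predict` returns hard labels, but logistic regression also exposes class probabilities through `predict_proba`. For a binary model, `predict` is equivalent to labelling a prompt as infected when its probability is above 0.5, and moving that threshold shifts errors from one class to the other. The sketch below uses synthetic data from `make_classification` as a stand-in for the embeddings and assumes label 1 means prompt injection; `predict_with_threshold` is a hypothetical helper.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# Synthetic, imbalanced stand-in for the embeddings (assumption: 0 = good, 1 = injection)
X, y = make_classification(n_samples=500, n_features=20, weights=[0.25, 0.75],
                           random_state=0)
clf = LogisticRegression(random_state=42, max_iter=1000).fit(X, y)

# Column order follows clf.classes_, i.e. [0, 1], so column 1 is P(injection)
proba_infected = clf.predict_proba(X)[:, 1]


def predict_with_threshold(proba, threshold=0.5):
    """Label as infected (1) only when the probability exceeds the threshold."""
    return (proba > threshold).astype(int)


default_pred = predict_with_threshold(proba_infected, 0.5)  # same as clf.predict(X)
strict_pred = predict_with_threshold(proba_infected, 0.8)   # flags fewer prompts
print((default_pred == 0).sum(), (strict_pred == 0).sum())
```

Raising the threshold can only turn "infected" predictions into "good" ones, so more queries end up labelled good.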
Now, we can test the model on our test dataset by running the code below.
from sklearn.metrics import classification_report
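def generate_report(y_true, y_pred):
    # Per-class precision, recall and F1; assumes label 0 = "Good Queries", 1 = "Infected Queries"
    report = classification_report(
        y_true, y_pred,
        target_names=["Good Queries", "Infected Queries"],
    )
    print(report)
# Predict labels for the test set and print the report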
y_pred = clf.predict(test_embeddings)
generate_report(test_set["Prompt injection"], y_pred)
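Every number in the classification report is derived from a confusion matrix, which can be easier to reason about because it shows raw counts of each kind of mistake. A sketch using `sklearn.metrics.confusion_matrix`, assuming 0 = good and 1 = infected as in `target_names`; the toy labels are made up, and with the real data you would pass `test_set["Prompt injection"]` and `y_pred`:

```python
from sklearn.metrics import confusion_matrix

# Toy labels (0 = good query, 1 = infected query)
y_true = [0, 0, 1, 1, 1]
y_pred = [0, 1, 1, 1, 0]

# Rows are true classes, columns are predicted classes, both in the order [0, 1]
cm = confusion_matrix(y_true, y_pred, labels=[0, 1])
print(cm)
# [[1 1]
#  [1 2]]

good_recall = cm[0, 0] / cm[0].sum()      # good queries correctly kept as good
infected_recall = cm[1, 1] / cm[1].sum()  # infected queries correctly flagged
```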
This will print the following report.
                  precision    recall  f1-score   support

    Good Queries       0.88      0.53      0.66       709
Infected Queries       0.88      0.98      0.93      2494

        accuracy                           0.88      3203
       macro avg       0.88      0.76      0.79      3203
    weighted avg       0.88      0.88      0.87      3203
The classification report provides details about the model’s performance, including:
- Precision: Precision measures the proportion of true positive predictions among all positive predictions. A precision of 0.88 for both classes indicates that when the model predicts a query as good or infected, it is correct 88% of the time.
- Recall: Recall (or Sensitivity) measures the proportion of actual positives correctly predicted by the model. The recall for “Good Queries” is relatively low at 0.53, which means the model misses a significant number of good queries: 47% of them are wrongly classified as infected. In contrast, the recall for “Infected Queries” is high at 0.98, indicating that the model is very effective at identifying infected queries.
- F1 score: The F1-score is the harmonic mean of precision and recall, providing a balance between the two. The F1-score for “Good Queries” is 0.66, which is acceptable but indicates room for improvement, especially given the low recall. The F1-score for “Infected Queries” is strong at 0.93, suggesting that the model performs very well for this class.
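- Accuracy: Overall accuracy is 0.88, meaning that 88% of the model’s predictions are correct. However, this metric can be misleading on imbalanced datasets like this one: the test set contains 2,494 infected and only 709 good queries (about 78% infected), so a model that labelled every query as infected would already reach about 78% accuracy. Accuracy alone does not show how well each class is predicted.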
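The F1 and average rows of the report can be reproduced from the per-class precision, recall, and support values. A short worked check using the rounded figures from the report, so the results agree to two decimal places:

```python
def f1(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall)


# Per-class values copied from the classification report
good = {"precision": 0.88, "recall": 0.53, "support": 709}
infected = {"precision": 0.88, "recall": 0.98, "support": 2494}
total = good["support"] + infected["support"]  # 3203

f1_good = f1(good["precision"], good["recall"])
f1_infected = f1(infected["precision"], infected["recall"])

# Macro average: unweighted mean over the two classes
macro_f1 = (f1_good + f1_infected) / 2

# Weighted average: each class weighted by its support
weighted_f1 = (f1_good * good["support"] + f1_infected * infected["support"]) / total
weighted_recall = (good["recall"] * good["support"]
                   + infected["recall"] * infected["support"]) / total

print(round(f1_good, 2), round(f1_infected, 2))   # 0.66 0.93
print(round(macro_f1, 2), round(weighted_f1, 2))  # 0.79 0.87
print(round(weighted_recall, 2))                  # 0.88
```

The support-weighted recall equals the overall accuracy, because it is simply the number of correct predictions divided by the total. The macro average, by contrast, gives the small “Good Queries” class equal say, which is why its recall (0.76) is noticeably lower.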
In the next post we will look at how we can improve Recall for Good Queries.
I am building a course on how to build production apps using LLMs. We will cover topics like prompt engineering, RAG, search, testing and evals, fine-tuning, feedback analysis, and agents. You can register now and get a 50% discount. Register using this form: https://forms.gle/twuVNs9SeHzMt8q68
Discover more from Shekhar Gulati