Model Description
The purpose of our trained Random Forest models is to identify malicious prompts given the prompt embeddings derived from OpenAI, OctoAI, and MiniLM. The models are trained with 373,598 benign and malicious prompts. We split this dataset into 80% training and 20% test sets. To ensure equal proportion of the malicious and benign labels across splits, we use stratified sampling.
Embeddings consist of fixed-length numerical representations. For example, OpenAI generates an embedding vector consisting of 1,536 floating-point numbers for each prompt. Similarly, the embedding datasets for OctoAI and MiniLM consist of 1,027 and 387 features, respectively.
Model Evaluation
The binary classification performance of embedding-based random forest models is shared below:
Embedding | Precision | Recall | F1-score | AUC |
---|---|---|---|---|
OpenAI | 0.867 | 0.867 | 0.867 | 0.764 |
OctoAI | 0.849 | 0.853 | 0.851 | 0.731 |
MiniLM | 0.849 | 0.853 | 0.851 | 0.730 |
How To Use The Model
We have shared three versions of random forest models in this repository. We used the following embedding models: text-embedding-3-small
from OpenAI, and the open-source models gte-large
hosted on OctoAI, as well as the well-known all-MiniLM-L6-v2
. Therefore, you need to covert the prompts to its respective embeddings before querying the model to obtain its prediction: 0
for benign and 1
for malicous.
Citing This Work
Our implementation, along with the curated datasets used for evaluation, is available on GitHub. Additionaly, if you use our implementation for scientific research, you are highly encouraged to cite our paper.
@article{ayub2024embedding,
title={Embedding-based classifiers can detect prompt injection attacks},
author={Ayub, Md Ahsan and Majumdar, Subhabrata},
booktitle={CAMLIS},
year={2024}
}