Model Description

The purpose of our trained Random Forest models is to identify malicious prompts given the prompt embeddings derived from OpenAI, OctoAI, and MiniLM. The models are trained with 373,598 benign and malicious prompts. We split this dataset into 80% training and 20% test sets. To ensure equal proportion of the malicious and benign labels across splits, we use stratified sampling.

Embeddings consist of fixed-length numerical representations. For example, OpenAI generates an embedding vector consisting of 1,536 floating-point numbers for each prompt. Similarly, the embedding datasets for OctoAI and MiniLM consist of 1,027 and 387 features, respectively.

Model Evaluation

The binary classification performance of embedding-based random forest models is shared below:

Embedding Precision Recall F1-score AUC
OpenAI 0.867 0.867 0.867 0.764
OctoAI 0.849 0.853 0.851 0.731
MiniLM 0.849 0.853 0.851 0.730

How To Use The Model

We have shared three versions of random forest models in this repository. We used the following embedding models: text-embedding-3-small from OpenAI, and the open-source models gte-large hosted on OctoAI, as well as the well-known all-MiniLM-L6-v2. Therefore, you need to covert the prompts to its respective embeddings before querying the model to obtain its prediction: 0 for benign and 1 for malicous.

Citing This Work

Our implementation, along with the curated datasets used for evaluation, is available on GitHub. Additionaly, if you use our implementation for scientific research, you are highly encouraged to cite our paper.

@article{ayub2024embedding,
  title={Embedding-based classifiers can detect prompt injection attacks},
  author={Ayub, Md Ahsan and Majumdar, Subhabrata},
  booktitle={CAMLIS},
  year={2024}
}
Downloads last month

-

Downloads are not tracked for this model. How to track
Inference Providers NEW
This model is not currently available via any of the supported Inference Providers.
The model cannot be deployed to the HF Inference API: The model has no library tag.

Datasets used to train ahsanayub/malicious-prompts-detection-random-forest