---
library_name: transformers
tags:
- cybersecurity
- mpnet
- classification
- fine-tuned
---

# Model Card for MPNet Cybersecurity Classifier

This is a fine-tuned MPNet model specialized for classifying cybersecurity threat groups based on textual descriptions of their tactics and techniques.

## Model Details

### Model Description

The classifier takes an English-language description of tactics, techniques, and procedures (TTPs) and predicts which known cybersecurity threat group the description belongs to.

- **Developed by:** Dženan Hamzić
- **Model type:** Transformer-based classification model (MPNet)
- **Language(s) (NLP):** English
- **License:** Apache-2.0
- **Finetuned from model:** microsoft/mpnet-base (with intermediate MLM fine-tuning)

### Model Sources

- **Base Model:** [microsoft/mpnet-base](https://huggingface.co/microsoft/mpnet-base)

## Uses

### Direct Use

This model classifies textual cybersecurity descriptions into known cybersecurity threat groups.

### Downstream Use

The model can be integrated into Cyber Threat Intelligence (CTI) platforms, Security Operations Center (SOC) incident analysis tools, and automated threat detection systems.

### Out-of-Scope Use

- General-purpose language tasks and any other task outside the cybersecurity domain

## Bias, Risks, and Limitations

This model specializes in cybersecurity contexts. Predictions for unrelated contexts may be inaccurate. Even on in-domain test data the reported classification accuracy is 0.7161, so a substantial share of predictions can be wrong.

### Recommendations

Always verify predictions with cybersecurity analysts before using them in critical decision-making scenarios.

## How to Get Started with the Model

```python
from transformers import AutoTokenizer, MPNetModel
import torch

model_name = "mpnet_classification_finetuned_v2"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = MPNetModel.from_pretrained(model_name)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)
model.eval()

# NOTE: MPNetModel loads only the encoder. `classifier_model` below is the
# trained classifier wrapper (encoder + classification head); it is not
# defined in this snippet and must be restored separately before inference.

# Example inference
sentence = "APT38 has used phishing emails with malicious links to distribute malware."
inputs = tokenizer(sentence, return_tensors="pt", truncation=True, padding="max_length", max_length=128).to(device)

with torch.no_grad():
    outputs = model(**inputs)
    # Use the hidden state of the first token as the sentence embedding
    cls_embedding = outputs.last_hidden_state[:, 0, :]
    predicted_class = classifier_model.classifier(cls_embedding).argmax(dim=1).cpu().item()

# This is the numeric label index, not yet the original GroupID
print(f"Predicted label index: {predicted_class}")
```

## Training Details

### Training Data

The training dataset comprises balanced textual descriptions of various cybersecurity threat groups' TTPs, augmented through synonym replacement to increase diversity.

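Synonym replacement can be sketched as follows. The snippet is illustrative only: the `SYNONYMS` table, the `synonym_replace` helper, and the replacement probability are hypothetical and are not taken from the actual training pipeline.

```python
import random

# Hypothetical synonym table; the real augmentation vocabulary is not documented.
SYNONYMS = {
    "used": ["employed", "leveraged"],
    "malicious": ["harmful"],
    "distribute": ["deliver", "spread"],
}


def synonym_replace(sentence, rng, prob=0.5):
    """Swap each word found in SYNONYMS for a random synonym with probability `prob`."""
    augmented = []
    for word in sentence.split():
        options = SYNONYMS.get(word.lower())
        if options and rng.random() < prob:
            augmented.append(rng.choice(options))
        else:
            augmented.append(word)
    return " ".join(augmented)


rng = random.Random(0)  # fixed seed so the augmentation is reproducible
original = "APT38 has used phishing emails with malicious links to distribute malware."
print(synonym_replace(original, rng))
```

Words with attached punctuation (such as `malware.`) are left untouched by this simple whitespace split; a real pipeline would likely tokenize more carefully.
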
### Training Procedure

- Fine-tuned from: MLM fine-tuned MPNet ("mpnet_mlm_cyber_finetuned-v2")
- Epochs: 20
- Learning rate: 5e-6
- Batch size: 16

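The hyperparameters above can be collected in a small config object, which also shows how many optimizer steps a run implies. The `TrainingConfig` class and the dataset size of 1,000 examples are placeholders; this card does not state the optimizer, scheduler, or dataset size.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class TrainingConfig:
    # Values taken from the hyperparameter list in this card
    base_checkpoint: str = "mpnet_mlm_cyber_finetuned-v2"
    epochs: int = 20
    learning_rate: float = 5e-6
    batch_size: int = 16

    def total_steps(self, num_examples: int) -> int:
        """Optimizer steps for a full run, assuming one step per batch."""
        return math.ceil(num_examples / self.batch_size) * self.epochs


config = TrainingConfig()
print(config.total_steps(num_examples=1000))  # 1000 is a placeholder dataset size
```
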
## Evaluation

### Testing Data, Factors & Metrics

- **Testing Data:** Stratified sample from original dataset.
- **Metrics:** Accuracy, Weighted F1 Score

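For reference, both metrics can be computed from paired lists of true and predicted labels. Weighted F1 is the per-class F1 score averaged with each class weighted by its number of true examples (its support). The sketch below uses only the standard library; the GroupIDs and predictions are made up and are not results from this model.

```python
from collections import Counter


def accuracy(y_true, y_pred):
    """Fraction of predictions that match the true label."""
    return sum(t == p for t, p in zip(y_true, y_pred)) / len(y_true)


def weighted_f1(y_true, y_pred):
    """Per-class F1, averaged with weights equal to each class's support."""
    support = Counter(y_true)
    total = 0.0
    for label, n in support.items():
        tp = sum(t == label and p == label for t, p in zip(y_true, y_pred))
        fp = sum(t != label and p == label for t, p in zip(y_true, y_pred))
        fn = n - tp
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        total += f1 * n
    return total / len(y_true)


# Made-up labels for illustration only
y_true = ["G9001", "G9001", "G9002", "G9002", "G9003", "G9003"]
y_pred = ["G9001", "G9002", "G9002", "G9002", "G9003", "G9001"]
print(f"accuracy={accuracy(y_true, y_pred):.4f}")
print(f"weighted_f1={weighted_f1(y_true, y_pred):.4f}")
```
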
### Results

| Metric                         | Value                     |
|--------------------------------|---------------------------|
| Classification Accuracy (Test) | 0.7161                    |
| Weighted F1 Score              | [More Information Needed] |

### Single Prediction Example

```python
import torch

# Assumes `train_df`, `tokenizer`, `classifier_model`, and `device` are
# already defined earlier in the session.

# Create explicit mapping from numeric labels to original GroupIDs
label_to_groupid = dict(enumerate(train_df["GroupID"].astype("category").cat.categories))

def predict_group(sentence):
    classifier_model.eval()
    encoding = tokenizer(
        sentence,
        truncation=True,
        padding="max_length",
        max_length=128,
        return_tensors="pt"
    )
    input_ids = encoding["input_ids"].to(device)
    attention_mask = encoding["attention_mask"].to(device)

    with torch.no_grad():
        logits = classifier_model(input_ids, attention_mask)
        predicted_label = torch.argmax(logits, dim=1).cpu().item()

    # Explicitly convert numeric label to original GroupID
    predicted_groupid = label_to_groupid[predicted_label]
    return predicted_groupid

sentence = "APT38 has used phishing emails with malicious links to distribute malware."
predicted_class = predict_group(sentence)
print(f"Predicted GroupID: {predicted_class}")  # e.g., Predicted GroupID: G0081
```

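The label mapping in the example is rebuilt from `train_df`, which may not be available at inference time. One option is to save the mapping alongside the model and reload it later. The sketch below uses the standard library only; the GroupIDs and file name are made up, and `sorted(set(...))` stands in for the pandas category ordering, which for string labels is likewise the sorted unique values.

```python
import json
import tempfile
from pathlib import Path

# Made-up training labels for illustration only
group_ids = ["G9003", "G9001", "G9003", "G9002"]

# Index -> GroupID, in sorted order of the unique labels
label_to_groupid = dict(enumerate(sorted(set(group_ids))))

# Hypothetical file name; store it next to the model weights
path = Path(tempfile.mkdtemp()) / "label_to_groupid.json"
path.write_text(json.dumps(label_to_groupid))

# JSON object keys are always strings, so convert them back to int on load
restored = {int(k): v for k, v in json.loads(path.read_text()).items()}
print(restored)
```
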
## Environmental Impact

Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute).

- **Hardware Type:** [To be filled by user]
- **Hours used:** [To be filled by user]
- **Cloud Provider:** [To be filled by user]
- **Compute Region:** [To be filled by user]
- **Carbon Emitted:** [To be filled by user]

## Technical Specifications

### Model Architecture

- MPNet encoder with a classification head (768 -> 512 -> num_labels)
- The last 10 transformer layers are fine-tuned

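A minimal sketch of how such an architecture could be assembled is shown below. It is not the author's implementation: the `MPNetClassifier` class name, the ReLU activation between the two linear layers, the choice to freeze everything outside the last 10 layers, and `num_labels=5` are all assumptions made for illustration. The encoder here is randomly initialised from the default `MPNetConfig` so that nothing has to be downloaded; in practice it would be loaded with `MPNetModel.from_pretrained(...)`.

```python
import torch
from torch import nn
from transformers import MPNetConfig, MPNetModel


class MPNetClassifier(nn.Module):
    """Hypothetical wrapper: MPNet encoder + (hidden -> 512 -> num_labels) head."""

    def __init__(self, encoder, num_labels, trainable_layers=10):
        super().__init__()
        self.encoder = encoder
        hidden = encoder.config.hidden_size  # 768 for mpnet-base
        # ReLU is an assumption; the card only gives the layer sizes.
        self.classifier = nn.Sequential(
            nn.Linear(hidden, 512),
            nn.ReLU(),
            nn.Linear(512, num_labels),
        )
        # Assumption: freeze the whole encoder, then unfreeze the last N layers.
        for param in self.encoder.parameters():
            param.requires_grad = False
        for layer in self.encoder.encoder.layer[-trainable_layers:]:
            for param in layer.parameters():
                param.requires_grad = True

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        cls_embedding = outputs.last_hidden_state[:, 0, :]  # first-token embedding
        return self.classifier(cls_embedding)


# Randomly initialised encoder (no download); weights are NOT the fine-tuned ones.
encoder = MPNetModel(MPNetConfig())
model = MPNetClassifier(encoder, num_labels=5)  # num_labels=5 is a placeholder
model.eval()

# Dummy batch: 2 sequences of 16 arbitrary token ids
input_ids = torch.randint(5, 1000, (2, 16))
attention_mask = torch.ones_like(input_ids)
with torch.no_grad():
    logits = model(input_ids, attention_mask)
print(logits.shape)
```

The wrapper exposes a `classifier` attribute and a `forward(input_ids, attention_mask)` signature so that it lines up with how `classifier_model` is called in the inference examples in this card.
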
## Model Card Authors

- Dženan Hamzić

## Model Card Contact

- [More Information Needed]