Spaces:
Running
Running
import evaluate | |
import numpy as np | |
import wandb | |
from datasets import load_dataset | |
from transformers import ( | |
AutoModelForSequenceClassification, | |
AutoTokenizer, | |
DataCollatorWithPadding, | |
Trainer, | |
TrainingArguments, | |
) | |
def train_binary_classifier(
    project_name: str,
    entity_name: str,
    dataset_repo: str = "geekyrakshit/prompt-injection-dataset",
    model_name: str = "distilbert/distilbert-base-uncased",
    learning_rate: float = 2e-5,
    batch_size: int = 16,
    num_epochs: int = 2,
    weight_decay: float = 0.01,
):
    """Fine-tune a two-class (SAFE / INJECTION) sequence classifier.

    Starts a Weights & Biases run, loads ``dataset_repo`` with
    ``load_dataset``, tokenizes its ``"prompt"`` column, and trains
    ``model_name`` with a ``Trainer`` that evaluates and checkpoints once
    per epoch. The W&B run is always closed before the function exits,
    including when loading or training raises.

    Args:
        project_name: W&B project passed to ``wandb.init``.
        entity_name: W&B entity passed to ``wandb.init``.
        dataset_repo: Dataset identifier passed to ``load_dataset``. The
            loaded dataset must expose ``"train"`` and ``"test"`` splits
            with a ``"prompt"`` column (presumably alongside a label
            column the Trainer can consume -- verify against the dataset).
        model_name: Checkpoint used for both the tokenizer and the model.
        learning_rate: Learning rate forwarded to ``TrainingArguments``.
        batch_size: Per-device batch size for both training and evaluation.
        num_epochs: Number of training epochs.
        weight_decay: Weight decay forwarded to ``TrainingArguments``.

    Returns:
        Whatever ``trainer.train()`` returns.

    Raises:
        Any exception raised while loading the data/model or while
        training is re-raised after the W&B run has been marked as failed.
    """
    wandb.init(project=project_name, entity=entity_name)
    try:
        dataset = load_dataset(dataset_repo)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        def preprocess_function(examples):
            # Padding is deferred to the data collator; only truncate here.
            return tokenizer(examples["prompt"], truncation=True)

        tokenized_datasets = dataset.map(preprocess_function, batched=True)
        data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
        accuracy = evaluate.load("accuracy")

        def compute_metrics(eval_pred):
            predictions, labels = eval_pred
            # Reduce per-class scores to the index of the highest one.
            predictions = np.argmax(predictions, axis=1)
            return accuracy.compute(predictions=predictions, references=labels)

        id2label = {0: "SAFE", 1: "INJECTION"}
        label2id = {"SAFE": 0, "INJECTION": 1}
        model = AutoModelForSequenceClassification.from_pretrained(
            model_name,
            num_labels=2,
            id2label=id2label,
            label2id=label2id,
        )

        trainer = Trainer(
            model=model,
            args=TrainingArguments(
                output_dir="binary-classifier",
                learning_rate=learning_rate,
                per_device_train_batch_size=batch_size,
                per_device_eval_batch_size=batch_size,
                num_train_epochs=num_epochs,
                weight_decay=weight_decay,
                eval_strategy="epoch",
                save_strategy="epoch",
                load_best_model_at_end=True,
                # NOTE(review): pushing presumably requires Hub credentials
                # to be configured in the environment -- confirm.
                push_to_hub=True,
                report_to="wandb",
            ),
            train_dataset=tokenized_datasets["train"],
            eval_dataset=tokenized_datasets["test"],
            processing_class=tokenizer,
            data_collator=data_collator,
            compute_metrics=compute_metrics,
        )
        train_output = trainer.train()
    except BaseException:
        # Close the run on any failure (including Ctrl-C) so it does not
        # leak into later calls in the same process; a non-zero exit code
        # marks the run as failed rather than finished.
        wandb.finish(exit_code=1)
        raise
    wandb.finish()
    return train_output