"""Fine-tune a sequence-classification model to label prompts as safe or jailbreak."""

import argparse

import torch
from datasets import load_from_disk
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

# Command-line arguments: the checkpoint to fine-tune and where to write checkpoints.
argparser = argparse.ArgumentParser()
argparser.add_argument("--model", type=str, required=True)
argparser.add_argument("--output_dir", type=str, required=True)
args = argparser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model)


def preprocess_function(examples):
    # Tokenize the raw text; padding is deferred to the data collator below.
    return tokenizer(examples["text"], truncation=True)


# Load the pre-split datasets from disk and tokenize them.
train_data = load_from_disk("train_data")
test_data = load_from_disk("test_data")
tokenized_train_data = train_data.map(preprocess_function, batched=True)
tokenized_test_data = test_data.map(preprocess_function, batched=True)
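# Note: each split is assumed to provide "text" and "label" columns, with
# labels matching id2label below (0 = safe, 1 = jailbreak).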

# Dynamically pad each batch to the length of its longest sequence.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

id2label = {0: "safe", 1: "jailbreak"}
label2id = {"safe": 0, "jailbreak": 1}

model = AutoModelForSequenceClassification.from_pretrained(
    args.model, num_labels=2, id2label=id2label, label2id=label2id
)
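# If args.model is a base checkpoint without a classification head,
# from_pretrained initializes the head randomly (transformers logs a warning)
# and it is trained from scratch below.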


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    return {"accuracy": (predictions == labels).mean()}


def preprocess_logits_for_metrics(logits, labels):
    """
    The stock Trainer may leak memory during evaluation by accumulating
    tensors that are not needed. As a workaround, reduce the logits to
    predicted class ids so that only those are stored.
    """
    # Return only the predicted ids: returning (pred_ids, labels) would make
    # eval_pred.predictions a tuple and silently corrupt the accuracy metric.
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids


training_args = TrainingArguments(
    output_dir=args.output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Move accumulated eval tensors off the GPU every 16 steps to limit memory use.
    eval_accumulation_steps=16,
    eval_steps=500,
    num_train_epochs=1,
    weight_decay=0.01,
    evaluation_strategy="steps",  # renamed to eval_strategy in newer transformers releases
    # load_best_model_at_end requires the save and eval strategies to match,
    # so save on the same 500-step schedule rather than once per epoch.
    save_strategy="steps",
    save_steps=500,
    load_best_model_at_end=True,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_test_data,
    tokenizer=tokenizer,  # processing_class in newer transformers releases
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

trainer.train()
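
# Example invocation (script name and checkpoint are illustrative):
#   python train.py --model distilbert-base-uncased --output_dir ./jailbreak-classifier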