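# Fine-tune a sequence-classification model that labels prompts as "safe" or "jailbreak".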
import argparse

import torch
from datasets import load_from_disk
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)
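
# Read the base model name and the output directory from the command line.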
argparser = argparse.ArgumentParser()
argparser.add_argument("--model", type=str, required=True)
argparser.add_argument("--output_dir", type=str, required=True)
args = argparser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model)
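
# Truncate each example to the model's maximum length; per-batch padding is
# handled by DataCollatorWithPadding below.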
def preprocess_function(examples):
    return tokenizer(examples["text"], truncation=True)
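
# The train/test splits are expected in the on-disk format written by
# Dataset.save_to_disk.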
train_data = load_from_disk("train_data")
test_data = load_from_disk("test_data")
tokenized_train_data = train_data.map(preprocess_function, batched=True)
tokenized_test_data = test_data.map(preprocess_function, batched=True)
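
# Pad each batch to the length of its longest sequence at collation time.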
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)
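
# The label maps are stored in the model config, so saved checkpoints report
# human-readable class names.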
id2label = {0: "safe", 1: "jailbreak"}
label2id = {"safe": 0, "jailbreak": 1}
model = AutoModelForSequenceClassification.from_pretrained(
    args.model, num_labels=2, id2label=id2label, label2id=label2id
)
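
# Report plain accuracy over the predicted class ids.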
def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    return {"accuracy": (predictions == labels).mean()}

def preprocess_logits_for_metrics(logits, labels):
"""
Original Trainer may have a memory leak.
This is a workaround to avoid storing too many tensors that are not needed.
"""
pred_ids = torch.argmax(logits, dim=-1)
return pred_ids, labels
training_args = TrainingArguments(
    output_dir=args.output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Move accumulated predictions from GPU to CPU every 16 eval steps.
    eval_accumulation_steps=16,
    eval_steps=500,
    num_train_epochs=1,
    weight_decay=0.01,
    evaluation_strategy="steps",
    # load_best_model_at_end requires the save strategy to match the eval
    # strategy, so checkpoint on the same 500-step schedule.
    save_strategy="steps",
    save_steps=500,
    load_best_model_at_end=True,
    push_to_hub=False,
)
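
# The Trainer owns the training and evaluation loops.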
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_test_data,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
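
# Run fine-tuning; the best checkpoint (by eval loss) is reloaded at the end.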
trainer.train()