# Source: uploaded by sijju ("Upload folder using huggingface_hub"),
# commit 729b0f4, 2.15 kB.
import argparse
import torch
from datasets import load_from_disk
from transformers import (
AutoModelForSequenceClassification,
AutoTokenizer,
DataCollatorWithPadding,
Trainer,
TrainingArguments,
)
argparser = argparse.ArgumentParser(
    description="Fine-tune a model as a binary sequence classifier."
)
argparser.add_argument(
    "--model",
    type=str,
    required=True,
    help="Model name or local path passed to from_pretrained().",
)
argparser.add_argument(
    "--output_dir",
    type=str,
    required=True,
    help="Directory where the Trainer writes its outputs.",
)
args = argparser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model)
# Some tokenizers (e.g. GPT-2) define no padding token, which dynamic padding
# needs in order to batch variable-length inputs; reuse EOS in that case.
if tokenizer.pad_token is None and tokenizer.eos_token is not None:
    tokenizer.pad_token = tokenizer.eos_token
def preprocess_function(examples):
    """Tokenize the ``text`` column of a batch of examples.

    Used with ``Dataset.map(..., batched=True)``, so ``examples["text"]`` is a
    list of strings. Sequences are truncated to the tokenizer's default maximum
    length; padding is left to the data collator.
    """
    texts = examples["text"]
    encoded = tokenizer(texts, truncation=True)
    return encoded
# Load the two pre-split datasets (written with ``save_to_disk``), then
# tokenize each split in batches, in the same order: train first, then test.
train_data, test_data = (load_from_disk(path) for path in ("train_data", "test_data"))
tokenized_train_data, tokenized_test_data = (
    split.map(preprocess_function, batched=True) for split in (train_data, test_data)
)

# Pads every batch dynamically to its longest sequence.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

# Class-id <-> label-name maps; label2id is the exact inverse of id2label.
id2label = {0: "safe", 1: "jailbreak"}
label2id = {name: idx for idx, name in id2label.items()}
# Classification head with one output per label; the label maps make the saved
# config report "safe"/"jailbreak" instead of generic LABEL_0/LABEL_1 names.
model = AutoModelForSequenceClassification.from_pretrained(
    args.model, num_labels=len(id2label), id2label=id2label, label2id=label2id
)
# Decoder-style classification heads (e.g. GPT-2/Llama) locate the last real
# token via config.pad_token_id and reject batches > 1 when it is unset; mirror
# the tokenizer's pad token if the checkpoint's config leaves it empty.
if model.config.pad_token_id is None and tokenizer.pad_token_id is not None:
    model.config.pad_token_id = tokenizer.pad_token_id
def compute_metrics(eval_pred):
    """Return classification accuracy for an evaluation run.

    Args:
        eval_pred: ``(predictions, labels)`` pair, e.g. a
            ``transformers.EvalPrediction``. ``predictions`` holds predicted
            class ids (as produced by ``preprocess_logits_for_metrics``); a
            tuple/list whose first element is those ids is also accepted.

    Returns:
        ``{"accuracy": fraction of predictions equal to labels}``.
    """
    predictions, labels = eval_pred
    # If the logits hook returned several tensors, they arrive here as a tuple.
    # Comparing that tuple to ``labels`` directly would broadcast every element
    # (including any copy of the labels) and inflate the score, so keep only
    # the predicted ids.
    if isinstance(predictions, (tuple, list)):
        predictions = predictions[0]
    return {"accuracy": (predictions == labels).mean()}
def preprocess_logits_for_metrics(logits, labels):
    """Reduce a batch of logits to predicted class ids before Trainer stores them.

    Trainer accumulates whatever this hook returns for the whole evaluation set,
    so keeping only the argmax (one id per example instead of one score per
    class) bounds memory. Labels are gathered by Trainer separately and reach
    ``compute_metrics`` as ``label_ids``, so they are not returned here;
    returning them too would turn the predictions into a tuple.

    Args:
        logits: Model output logits, assumed ``(batch, num_labels)``. A tuple of
            outputs is accepted, in which case its first element is used.
        labels: Batch labels (unused; required by the hook's signature).

    Returns:
        Tensor of predicted class ids (argmax over the last dimension).
    """
    if isinstance(logits, tuple):
        logits = logits[0]
    return torch.argmax(logits, dim=-1)
# Evaluate and checkpoint on the same cadence: load_best_model_at_end requires
# matching save/eval strategies (and, for "steps", save_steps to be a multiple
# of eval_steps); otherwise TrainingArguments raises a ValueError.
EVAL_SAVE_STEPS = 500

training_args = TrainingArguments(
    output_dir=args.output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    # Move accumulated eval outputs to the CPU every 16 steps to cap device memory.
    eval_accumulation_steps=16,
    eval_steps=EVAL_SAVE_STEPS,
    num_train_epochs=1,
    weight_decay=0.01,
    # NOTE: newer transformers releases name this argument ``eval_strategy``.
    evaluation_strategy="steps",
    save_strategy="steps",
    save_steps=EVAL_SAVE_STEPS,
    # Keep disk use bounded; with load_best_model_at_end the best checkpoint is
    # always retained in addition to the most recent one.
    save_total_limit=2,
    load_best_model_at_end=True,
    push_to_hub=False,
)
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_test_data,
    # NOTE: newer transformers releases prefer ``processing_class=`` here.
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)
trainer.train()
# load_best_model_at_end only reloads the best checkpoint into memory; persist
# it (with the tokenizer passed above) to the configured output_dir.
trainer.save_model()