import argparse

import torch
from datasets import load_from_disk
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    DataCollatorWithPadding,
    Trainer,
    TrainingArguments,
)

argparser = argparse.ArgumentParser()
argparser.add_argument("--model", type=str, required=True)
argparser.add_argument("--output_dir", type=str, required=True)
args = argparser.parse_args()

tokenizer = AutoTokenizer.from_pretrained(args.model)


def preprocess_function(examples):
    # Tokenize the raw text; padding is deferred to the data collator.
    return tokenizer(examples["text"], truncation=True)


# The saved datasets are expected to provide "text" and "label" columns.
train_data = load_from_disk("train_data")
test_data = load_from_disk("test_data")

tokenized_train_data = train_data.map(preprocess_function, batched=True)
tokenized_test_data = test_data.map(preprocess_function, batched=True)

# Pad dynamically per batch rather than padding the whole dataset up front.
data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

id2label = {0: "safe", 1: "jailbreak"}
label2id = {"safe": 0, "jailbreak": 1}

model = AutoModelForSequenceClassification.from_pretrained(
    args.model, num_labels=2, id2label=id2label, label2id=label2id
)


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    # Because preprocess_logits_for_metrics returns a tuple, the Trainer hands
    # us (pred_ids, labels) as the predictions; keep only the predicted ids.
    if isinstance(predictions, tuple):
        predictions = predictions[0]
    return {"accuracy": (predictions == labels).mean()}


def preprocess_logits_for_metrics(logits, labels):
    """
    The stock Trainer accumulates the full logits over the evaluation set,
    which can exhaust memory. Reducing logits to predicted ids here is a
    workaround that avoids storing tensors we do not need.
    """
    pred_ids = torch.argmax(logits, dim=-1)
    return pred_ids, labels


training_args = TrainingArguments(
    output_dir=args.output_dir,
    learning_rate=2e-5,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    eval_accumulation_steps=16,
    eval_steps=500,
    num_train_epochs=1,
    weight_decay=0.01,
    evaluation_strategy="steps",
    # load_best_model_at_end requires the save and evaluation strategies to
    # match, so checkpoint on the same step interval as evaluation.
    save_strategy="steps",
    save_steps=500,
    load_best_model_at_end=True,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train_data,
    eval_dataset=tokenized_test_data,
    tokenizer=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
)

trainer.train()
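
# Example invocation (the script filename and base model below are placeholders,
# not taken from this script; substitute your own):
#   python train_classifier.py --model distilbert-base-uncased --output_dir ./jailbreak-classifier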