from datasets import load_dataset
from sklearn.metrics import f1_score, accuracy_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)
from model import SciBertPaperClassifier
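
# Map each example's category string to its integer id; label2id is built in
# the __main__ block below before this function is applied via Dataset.map.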
def encode_labels(example):
example["labels"] = label2id[example["category"]]
return example
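
# Build a single input string per paper (authors + title + abstract) and
# tokenize it to a fixed length of 256 tokens.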
def preprocess_function(examples):
    texts = [
        f"AUTHORS: {' '.join(a) if isinstance(a, list) else a} TITLE: {t} ABSTRACT: {ab}"
        for a, t, ab in zip(
            examples["authors"], examples["title"], examples["abstract"]
        )
    ]
    return tokenizer(texts, truncation=True, padding="max_length", max_length=256)
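
# Accuracy and weighted F1 on the validation set, computed from the raw logits.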
def compute_metrics(pred):
    labels = pred.label_ids
    logits = pred.predictions
    preds = logits.argmax(-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="weighted"),
    }
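
# End-to-end fine-tuning pipeline: load data, build labels, tokenize, train,
# save, and evaluate.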
if __name__ == "__main__":
print("DOWNLOADING DATASET...")
data_files = {"train": "arxiv_train.json", "test": "arxiv_test.json"}
dataset = load_dataset("json", data_files=data_files)
dataset["train"] = dataset["train"].shuffle(seed=42).select(range(100000))
print(f"DATA IS READY. TRAIN: {len(dataset['train'])}")
print("LABELING...")
unique_labels = sorted(set(example["category"] for example in dataset["train"]))
label2id = {label: idx for idx, label in enumerate(unique_labels)}
id2label = {idx: label for label, idx in label2id.items()}
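    # Attach integer labels, then hold out 10% of the training data for validation.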
dataset["train"] = dataset["train"].map(encode_labels)
split_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
train_dataset = split_dataset["train"]
valid_dataset = split_dataset["test"]
print(f"TRAIN SET: {len(train_dataset)}, VALIDATION SET: {len(valid_dataset)}")
print("TOKENIZATION...")
model_name = "allenai/scibert_scivocab_uncased"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
encoded_train = train_dataset.map(preprocess_function, batched=True, batch_size=32)
encoded_valid = valid_dataset.map(preprocess_function, batched=True, batch_size=32)
encoded_train.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
encoded_valid.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
print("TOKENIZATION COMPLETED")
print("DOWNLOADING MODEL...")
model = AutoModelForSequenceClassification.from_pretrained(
model_name,
num_labels=len(unique_labels),
id2label=id2label,
label2id=label2id,
)
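    # Training hyperparameters: 2 epochs, batch size 32, LR 3e-5, mixed precision,
    # evaluation every 100 steps and checkpoints every 200 steps.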
    training_args = TrainingArguments(
        output_dir="./dataset_output",
        report_to="none",
        eval_strategy="steps",
        eval_steps=100,
        logging_steps=200,
        disable_tqdm=True,
        learning_rate=3e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=2,
        save_steps=200,
        fp16=True,
        remove_unused_columns=False,
    )
print("LEARNING...")
trainer = Trainer(
model=model,
args=training_args,
train_dataset=encoded_train,
eval_dataset=encoded_valid,
compute_metrics=compute_metrics,
)
trainer.train()
print("LEARNING COMPLETED")
    model.save_pretrained("trained_model")
    tokenizer.save_pretrained("trained_model")
    print("EVALUATION...")
    final_metrics = trainer.evaluate()
    print("METRICS:")
    for key, value in final_metrics.items():
        print(f"{key}: {value:.4f}")