from datasets import load_dataset
from sklearn.metrics import accuracy_score, f1_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)


def encode_labels(example):
    # Map the string category to its integer id (label2id is built in __main__).
    example["labels"] = label2id[example["category"]]
    return example


def preprocess_function(examples):
    # Concatenate authors, title, and abstract into a single input string per paper.
    texts = [
        f"AUTHORS: {' '.join(a) if isinstance(a, list) else a} TITLE: {t} ABSTRACT: {ab}"
        for a, t, ab in zip(
            examples["authors"], examples["title"], examples["abstract"]
        )
    ]
    return tokenizer(texts, truncation=True, padding="max_length", max_length=256)


def compute_metrics(pred):
    labels = pred.label_ids
    logits = pred.predictions
    preds = logits.argmax(-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="weighted"),
    }


if __name__ == "__main__":
    print("DOWNLOADING DATASET...")
    data_files = {"train": "arxiv_train.json", "test": "arxiv_test.json"}
    dataset = load_dataset("json", data_files=data_files)
    # Subsample 100k shuffled training examples for a faster run.
    dataset["train"] = dataset["train"].shuffle(seed=42).select(range(100000))
    print(f"DATA IS READY. TRAIN: {len(dataset['train'])}")

    print("LABELING...")
    unique_labels = sorted(set(example["category"] for example in dataset["train"]))
    label2id = {label: idx for idx, label in enumerate(unique_labels)}
    id2label = {idx: label for label, idx in label2id.items()}
    dataset["train"] = dataset["train"].map(encode_labels)

    # Hold out 10% of the training data for validation.
    split_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
    train_dataset = split_dataset["train"]
    valid_dataset = split_dataset["test"]
    print(f"TRAIN SET: {len(train_dataset)}, VALIDATION SET: {len(valid_dataset)}")

    print("TOKENIZATION...")
    model_name = "allenai/scibert_scivocab_uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
    encoded_train = train_dataset.map(preprocess_function, batched=True, batch_size=32)
    encoded_valid = valid_dataset.map(preprocess_function, batched=True, batch_size=32)
    # Expose only the tensor columns the model consumes.
    encoded_train.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    encoded_valid.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    print("TOKENIZATION COMPLETED")

    print("DOWNLOADING MODEL...")
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(unique_labels),
        id2label=id2label,
        label2id=label2id,
    )

    training_args = TrainingArguments(
        output_dir="./dataset_output",
        report_to="none",
        eval_strategy="steps",
        eval_steps=100,
        logging_steps=200,
        disable_tqdm=True,
        learning_rate=3e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=2,
        save_steps=200,
        fp16=True,
        remove_unused_columns=False,
    )

    print("LEARNING...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=encoded_train,
        eval_dataset=encoded_valid,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    print("LEARNING COMPLETED")

    model.save_pretrained("trained_model")
    tokenizer.save_pretrained("trained_model")

    print("EVALUATION...")
    final_metrics = trainer.evaluate()
    print("METRICS:")
    for key, value in final_metrics.items():
        print(f"{key}: {value:.4f}")
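
# --- Usage sketch (commented out; not part of the training run) ---
# A minimal example of loading the saved model for inference, assuming the
# script above has completed and written its artifacts to "trained_model/".
# The sample title/abstract text below is hypothetical; real inputs should
# follow the same "AUTHORS: ... TITLE: ... ABSTRACT: ..." format used in
# preprocess_function.
#
#     from transformers import pipeline
#
#     classifier = pipeline(
#         "text-classification", model="trained_model", tokenizer="trained_model"
#     )
#     text = "AUTHORS: J Doe TITLE: An example paper ABSTRACT: A short abstract."
#     print(classifier(text, truncation=True))  # e.g. [{'label': ..., 'score': ...}]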