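# Fine-tune SciBERT (allenai/scibert_scivocab_uncased) to predict an arXiv
# paper's category from its authors, title, and abstract; the trained model
# and tokenizer are saved to ./trained_model.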
from datasets import load_dataset
from sklearn.metrics import f1_score, accuracy_score
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    Trainer,
    TrainingArguments,
)

def encode_labels(example):
    # Map the string category to its integer id; label2id is defined in the
    # main block before this function is applied via Dataset.map().
    example["labels"] = label2id[example["category"]]
    return example

def preprocess_function(examples):
    # Concatenate authors, title, and abstract into one tagged string per
    # paper, then tokenize to fixed-length sequences.
    texts = [
        f"AUTHORS: {' '.join(a) if isinstance(a, list) else a} TITLE: {t} ABSTRACT: {ab}"
        for a, t, ab in zip(
            examples["authors"], examples["title"], examples["abstract"]
        )
    ]
    return tokenizer(texts, truncation=True, padding="max_length", max_length=256)

def compute_metrics(pred):
    # Accuracy plus weighted F1, which accounts for class imbalance across
    # categories.
    labels = pred.label_ids
    logits = pred.predictions
    preds = logits.argmax(-1)
    return {
        "accuracy": accuracy_score(labels, preds),
        "f1": f1_score(labels, preds, average="weighted"),
    }
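
# Minimal inference sketch: load the checkpoint that the training run below
# saves under "trained_model" and classify a single paper. The text format
# mirrors preprocess_function; the helper name and signature are illustrative,
# and nothing in the training flow calls it.
def predict_category(authors, title, abstract, model_dir="trained_model"):
    import torch

    tok = AutoTokenizer.from_pretrained(model_dir)
    clf = AutoModelForSequenceClassification.from_pretrained(model_dir)
    text = (
        f"AUTHORS: {' '.join(authors) if isinstance(authors, list) else authors} "
        f"TITLE: {title} ABSTRACT: {abstract}"
    )
    inputs = tok(
        text, truncation=True, padding="max_length", max_length=256, return_tensors="pt"
    )
    with torch.no_grad():
        logits = clf(**inputs).logits
    return clf.config.id2label[int(logits.argmax(-1))]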

if __name__ == "__main__":
    print("LOADING DATASET...")
    data_files = {"train": "arxiv_train.json", "test": "arxiv_test.json"}
    dataset = load_dataset("json", data_files=data_files)

    # Shuffle and subsample 100k examples to keep the fine-tuning run tractable.
    dataset["train"] = dataset["train"].shuffle(seed=42).select(range(100000))
    print(f"DATA IS READY. TRAIN: {len(dataset['train'])}")

    print("LABELING...")
    # Build the label vocabulary from the categories seen in the training split.
    unique_labels = sorted(set(example["category"] for example in dataset["train"]))
    label2id = {label: idx for idx, label in enumerate(unique_labels)}
    id2label = {idx: label for label, idx in label2id.items()}

    dataset["train"] = dataset["train"].map(encode_labels)

    # Hold out 10% of the subsample as a validation set.
    split_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
    train_dataset = split_dataset["train"]
    valid_dataset = split_dataset["test"]
    print(f"TRAIN SET: {len(train_dataset)}, VALIDATION SET: {len(valid_dataset)}")

    print("TOKENIZATION...")
    model_name = "allenai/scibert_scivocab_uncased"
    tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

    encoded_train = train_dataset.map(preprocess_function, batched=True, batch_size=32)
    encoded_valid = valid_dataset.map(preprocess_function, batched=True, batch_size=32)
    # Expose only the tensor columns the Trainer consumes.
    encoded_train.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    encoded_valid.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
    print("TOKENIZATION COMPLETED")

    print("DOWNLOADING MODEL...")
    # Load SciBERT with a freshly initialized classification head sized to the
    # label set.
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        num_labels=len(unique_labels),
        id2label=id2label,
        label2id=label2id,
    )

    training_args = TrainingArguments(
        output_dir="./dataset_output",
        report_to="none",
        eval_strategy="steps",
        eval_steps=100,
        logging_steps=200,
        disable_tqdm=True,
        learning_rate=3e-5,
        per_device_train_batch_size=32,
        per_device_eval_batch_size=32,
        num_train_epochs=2,
        save_steps=200,
        fp16=True,  # mixed-precision training; assumes a CUDA-capable GPU
        remove_unused_columns=False,
    )

    print("TRAINING...")
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=encoded_train,
        eval_dataset=encoded_valid,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    print("TRAINING COMPLETED")

    # Persist the fine-tuned model and tokenizer for later inference
    # (see predict_category above).
    model.save_pretrained("trained_model")
    tokenizer.save_pretrained("trained_model")

    print("EVALUATION...")
    # Final metrics are computed on the validation split; the arxiv_test.json
    # file loaded above is not evaluated here.
    final_metrics = trainer.evaluate()
    print("METRICS:")
    for key, value in final_metrics.items():
        print(f"{key}: {value:.4f}")
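
    # Example (illustrative): classify one paper with the saved checkpoint.
    # print(predict_category(["A. Author"], "A paper title", "An abstract..."))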