igorithm committed
Commit 30fc256 · verified · 1 parent: 3928452

Upload 2 files

category_classification/models/allenai__scibert_sci_vocab_uncased/model.py ADDED
@@ -0,0 +1,45 @@
+ from transformers import AutoModelForSequenceClassification, AutoTokenizer
+ import torch
+
+
+ class SciBertPaperClassifier:
+     """Wraps a fine-tuned SciBERT sequence classifier for paper categories."""
+
+     def __init__(self, model_path="trained_model"):
+         self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
+         self.tokenizer = AutoTokenizer.from_pretrained(model_path)
+         self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+         self.model.to(self.device)
+         self.model.eval()
+
+     def __call__(self, inputs):
+         # Flatten each paper dict into one "AUTHORS: ... TITLE: ... ABSTRACT: ..."
+         # string; authors may be a list of names or a single string.
+         texts = [
+             f"AUTHORS: {' '.join(authors) if isinstance(authors, list) else authors} "
+             f"TITLE: {paper['title']} ABSTRACT: {paper['abstract']}"
+             for paper in inputs
+             for authors in [paper.get("authors", "")]
+         ]
+
+         encoded = self.tokenizer(
+             texts, truncation=True, padding=True, max_length=256, return_tensors="pt"
+         ).to(self.device)
+
+         with torch.no_grad():
+             outputs = self.model(**encoded)
+
+         # Softmax over the logits, then take the top class and its probability.
+         probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
+         scores, labels = torch.max(probs, dim=1)
+
+         return [
+             [{"label": self.model.config.id2label[label.item()], "score": score.item()}]
+             for label, score in zip(labels, scores)
+         ]
+
+     # Make the classifier picklable: serialize the full state and move the
+     # model back onto the stored device when restored.
+     def __getstate__(self):
+         return self.__dict__
+
+     def __setstate__(self, state):
+         self.__dict__ = state
+         self.model.to(self.device)
+
+
+ def get_model():
+     return SciBertPaperClassifier()
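
For context, a minimal usage sketch of the classifier above. It assumes a trained_model directory already exists locally (e.g. produced by train.py below); the paper dicts mirror the field names that __call__ reads, but the concrete values are invented for illustration.

# Hypothetical usage sketch; assumes ./trained_model exists and model.py is importable.
from model import get_model

classifier = get_model()
papers = [
    {
        "authors": ["A. Author", "B. Author"],  # may also be a single string
        "title": "A Sample Paper Title",
        "abstract": "A short abstract describing the paper.",
    }
]
# Returns one [{"label": ..., "score": ...}] list per input paper.
print(classifier(papers))
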
category_classification/models/allenai__scibert_sci_vocab_uncased/train.py ADDED
@@ -0,0 +1,105 @@
+ from datasets import load_dataset
+ from sklearn.metrics import f1_score, accuracy_score
+ from transformers import (
+     AutoTokenizer,
+     AutoModelForSequenceClassification,
+     Trainer,
+     TrainingArguments,
+ )
+
+
+ # Both helpers rely on module-level globals (label2id, tokenizer) that are
+ # defined in the __main__ block before .map() is called.
+ def encode_labels(example):
+     example["labels"] = label2id[example["category"]]
+     return example
+
+
+ def preprocess_function(examples):
+     # Same text layout as model.py: "AUTHORS: ... TITLE: ... ABSTRACT: ...".
+     texts = [
+         f"AUTHORS: {' '.join(a) if isinstance(a, list) else a} TITLE: {t} ABSTRACT: {ab}"
+         for a, t, ab in zip(
+             examples["authors"], examples["title"], examples["abstract"]
+         )
+     ]
+     return tokenizer(texts, truncation=True, padding="max_length", max_length=256)
+
+
+ def compute_metrics(pred):
+     labels = pred.label_ids
+     logits = pred.predictions
+     preds = logits.argmax(-1)
+     return {
+         "accuracy": accuracy_score(labels, preds),
+         "f1": f1_score(labels, preds, average="weighted"),
+     }
+
+
+ if __name__ == "__main__":
+     print("DOWNLOADING DATASET...")
+     data_files = {"train": "arxiv_train.json", "test": "arxiv_test.json"}
+     dataset = load_dataset("json", data_files=data_files)
+
+     # Subsample 100k shuffled training examples for a faster fine-tune.
+     dataset["train"] = dataset["train"].shuffle(seed=42).select(range(100000))
+     print(f"DATA IS READY. TRAIN: {len(dataset['train'])}")
+
+     print("LABELING...")
+     unique_labels = sorted(set(example["category"] for example in dataset["train"]))
+     label2id = {label: idx for idx, label in enumerate(unique_labels)}
+     id2label = {idx: label for label, idx in label2id.items()}
+
+     dataset["train"] = dataset["train"].map(encode_labels)
+
+     # Hold out 10% of the subsampled training data for validation.
+     split_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
+     train_dataset = split_dataset["train"]
+     valid_dataset = split_dataset["test"]
+     print(f"TRAIN SET: {len(train_dataset)}, VALIDATION SET: {len(valid_dataset)}")
+
+     print("TOKENIZATION...")
+     model_name = "allenai/scibert_scivocab_uncased"
+     tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)
+
+     encoded_train = train_dataset.map(preprocess_function, batched=True, batch_size=32)
+     encoded_valid = valid_dataset.map(preprocess_function, batched=True, batch_size=32)
+     encoded_train.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
+     encoded_valid.set_format("torch", columns=["input_ids", "attention_mask", "labels"])
+     print("TOKENIZATION COMPLETED")
+
+     print("DOWNLOADING MODEL...")
+     model = AutoModelForSequenceClassification.from_pretrained(
+         model_name,
+         num_labels=len(unique_labels),
+         id2label=id2label,
+         label2id=label2id,
+     )
+
+     training_args = TrainingArguments(
+         output_dir="./dataset_output",
+         report_to="none",
+         eval_strategy="steps",
+         eval_steps=100,
+         logging_steps=200,
+         disable_tqdm=True,
+         learning_rate=3e-5,
+         per_device_train_batch_size=32,
+         per_device_eval_batch_size=32,
+         num_train_epochs=2,
+         save_steps=200,
+         fp16=True,
+         remove_unused_columns=False,
+     )
+
+     print("LEARNING...")
+     trainer = Trainer(
+         model=model,
+         args=training_args,
+         train_dataset=encoded_train,
+         eval_dataset=encoded_valid,
+         compute_metrics=compute_metrics,
+     )
+     trainer.train()
+     print("LEARNING COMPLETED")
+
+     # Persist the fine-tuned weights and tokenizer where model.py expects them.
+     model.save_pretrained("trained_model")
+     tokenizer.save_pretrained("trained_model")
+
+     print("EVALUATION...")
+     final_metrics = trainer.evaluate()
+     print("METRICS:")
+     for key, value in final_metrics.items():
+         print(f"{key}: {value:.4f}")
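
train.py loads arxiv_train.json and arxiv_test.json with the datasets json builder, so each record needs authors, title, abstract, and category fields. A sketch of one record's shape, with values invented for illustration:

# Hypothetical shape of one record in arxiv_train.json / arxiv_test.json.
example_record = {
    "authors": ["A. Author", "B. Author"],  # list of names or a single string
    "title": "A Sample Paper Title",
    "abstract": "A short abstract describing the paper.",
    "category": "cs.CL",  # target label; the set of categories defines label2id
}
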