File size: 1,536 Bytes
3928452
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch


class SciBertPaperClassifier:
    """Classify papers (authors + title + abstract) with a fine-tuned sequence classifier.

    Instances are callable: pass a list of paper dicts and receive, per paper,
    ``[{"label": ..., "score": ...}]`` where ``label`` is the arg-max class name
    (from ``model.config.id2label``) and ``score`` is its softmax probability.
    """

    # Default tokenizer truncation length. Kept at class level so instances
    # created without ``__init__`` (e.g. unpickled from an older version of this
    # class) still have a value.
    max_length = 256

    def __init__(self, model_path="trained_model", max_length=256):
        """Load the model and tokenizer from *model_path* and prepare for inference.

        Args:
            model_path: Path or identifier accepted by ``from_pretrained``.
            max_length: Maximum number of tokens per input; longer texts are truncated.
        """
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.max_length = max_length
        self.device = self._select_device()
        self.model.to(self.device)
        self.model.eval()

    @staticmethod
    def _select_device():
        """Return the CUDA device when available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    @staticmethod
    def _format_paper(paper):
        """Build the text fed to the tokenizer for one paper dict.

        NOTE(review): presumably mirrors the format used at training time — confirm.
        """
        # Missing or None authors become an empty string rather than "None".
        authors = paper.get("authors") or ""
        if isinstance(authors, (list, tuple)):
            authors = " ".join(authors)
        return (
            f"AUTHORS: {authors} "
            f"TITLE: {paper['title']} ABSTRACT: {paper['abstract']}"
        )

    def __call__(self, inputs):
        """Classify a batch of papers.

        Args:
            inputs: Iterable of dicts with required ``"title"`` and ``"abstract"``
                keys and an optional ``"authors"`` key (a string or a list/tuple
                of strings).

        Returns:
            One single-element list ``[{"label": str, "score": float}]`` per paper,
            in input order. An empty input yields ``[]``.

        Raises:
            KeyError: if a paper lacks ``"title"`` or ``"abstract"``.
        """
        texts = [self._format_paper(paper) for paper in inputs]
        if not texts:
            # Nothing to classify; avoid running the tokenizer/model on an empty batch.
            return []

        encoded = self.tokenizer(
            texts,
            truncation=True,
            padding=True,
            max_length=self.max_length,
            return_tensors="pt",
        ).to(self.device)

        with torch.no_grad():
            logits = self.model(**encoded).logits

        probs = torch.nn.functional.softmax(logits, dim=-1)
        scores, labels = probs.max(dim=-1)
        id2label = self.model.config.id2label
        # Transfer each result tensor to Python once, instead of one .item() per row.
        return [
            [{"label": id2label[label], "score": score}]
            for label, score in zip(labels.tolist(), scores.tolist())
        ]

    def __getstate__(self):
        """Return a shallow copy of the instance state for pickling."""
        return self.__dict__.copy()

    def __setstate__(self, state):
        """Restore pickled state and move the model to this host's best device.

        The device is re-selected rather than reused from *state*, so the
        restored instance uses whatever hardware the unpickling host offers.
        NOTE(review): CUDA tensors inside the pickled model may still fail to
        load on a CPU-only host; that happens before this method runs.
        """
        self.__dict__.update(state)
        self.device = self._select_device()
        self.model.to(self.device)


def get_model():
    """Create and return a ``SciBertPaperClassifier`` built with its default arguments."""
    classifier = SciBertPaperClassifier()
    return classifier