from transformers import AutoModelForSequenceClassification, AutoTokenizer
import torch


class SciBertPaperClassifier:
    """Classifies papers with a fine-tuned SciBERT sequence-classification model."""

    def __init__(self, model_path="trained_model"):
        self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def __call__(self, inputs):
        # Flatten each paper record into one "AUTHORS: ... TITLE: ... ABSTRACT: ..."
        # string; "authors" may be a list of names or a plain string.
        texts = [
            f"AUTHORS: {' '.join(authors) if isinstance(authors, list) else authors} "
            f"TITLE: {paper['title']} ABSTRACT: {paper['abstract']}"
            for paper in inputs
            for authors in [paper.get("authors", "")]
        ]
        # Renamed from "inputs" to avoid shadowing the method argument.
        encodings = self.tokenizer(
            texts, truncation=True, padding=True, max_length=256, return_tensors="pt"
        ).to(self.device)
        with torch.no_grad():
            outputs = self.model(**encodings)
        probs = torch.nn.functional.softmax(outputs.logits, dim=-1)
        scores, labels = torch.max(probs, dim=1)
        # One [{"label", "score"}] list per input paper, mirroring the
        # transformers text-classification pipeline output format.
        return [
            [{"label": self.model.config.id2label[label.item()], "score": score.item()}]
            for label, score in zip(labels, scores)
        ]

    def __getstate__(self):
        return self.__dict__

    def __setstate__(self, state):
        # After unpickling, move the model weights back onto the stored device.
        self.__dict__ = state
        self.model.to(self.device)


def get_model():
    return SciBertPaperClassifier()
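

# Minimal usage sketch. Assumes a fine-tuned checkpoint (with an id2label
# mapping in its config) exists at "trained_model/"; the paper dict below is a
# hypothetical example, not data shipped with this Space.
if __name__ == "__main__":
    classifier = get_model()
    papers = [
        {
            "authors": ["Jane Doe", "John Smith"],  # may also be a single string
            "title": "A Hypothetical Paper Title",
            "abstract": "A short abstract used only to illustrate the input format.",
        }
    ]
    predictions = classifier(papers)
    # e.g. [[{"label": "cs.CL", "score": 0.97}]] -- labels depend on the checkpoint
    print(predictions)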