Spaces:
Sleeping
Sleeping
""" | |
A script defining a text sentiment analysis tool for the 🤗 Transformers Agents library.
""" | |
from transformers import Tool | |
from transformers.tools.base import get_default_device | |
from transformers import pipeline | |
from transformers import DistilBertTokenizerFast | |
from trainDistilBERT import DistilBertForMulticlassSequenceClassification | |
import torch | |
class SentAnalClassifierTool(Tool):
    """
    A tool for sentiment analysis.

    Loads the checkpoint named by ``ckpt`` into a ``sentiment-analysis``
    pipeline on first use and returns the highest-scoring label for the
    text it is called with.
    """

    ckpt = "ongknsro/ACARISBERT-DistilBERT"
    name = "text_sentiment_analyzer"
    description = (
        "This is a tool that returns a sentiment label for a given text sequence. "
        "It takes the raw text as input, and "
        "returns a sentiment label as output."
    )
    inputs = ["text"]
    outputs = ["text"]

    def __init__(self, device=None, **hub_kwargs) -> None:
        """
        Store configuration; no model is loaded until ``setup`` runs.

        Args:
            device: Device to run the model on. When ``None``, ``setup``
                picks one via ``get_default_device()``.
            **hub_kwargs: Extra keyword arguments kept on the instance.
        """
        super().__init__()
        self.device = device
        self.pipeline = None
        # NOTE(review): hub_kwargs are stored but not forwarded to
        # `from_pretrained` below -- confirm whether they should be.
        self.hub_kwargs = hub_kwargs

    def setup(self):
        """
        Load the tokenizer and model and build the classification pipeline.

        Sets ``self.tokenizer``, ``self.model``, ``self.pipeline`` and marks
        the tool as initialized.
        """
        if self.device is None:
            self.device = get_default_device()

        self.tokenizer = DistilBertTokenizerFast.from_pretrained(self.ckpt)
        self.model = DistilBertForMulticlassSequenceClassification.from_pretrained(
            self.ckpt
        ).to(self.device)

        # Build the pipeline on the device the model was just moved to. A
        # hard-coded `device=0` would ignore the caller's choice and fail on
        # hosts without a CUDA device 0.
        self.pipeline = pipeline(
            "sentiment-analysis",
            model=self.model,
            tokenizer=self.tokenizer,
            top_k=None,
            device=self.device,
        )
        self.is_initialized = True

    def __call__(self, task: str):
        """
        Classify ``task`` and return the label with the highest score.

        Args:
            task: The raw text to analyse; passed unchanged to the pipeline.

        Returns:
            The ``"label"`` of the highest-scoring entry in the first
            element of the pipeline output.
        """
        # `is_initialized` is presumably set to False by Tool.__init__ --
        # verify against the base class.
        if not self.is_initialized:
            self.setup()

        outputs = self.pipeline(task)
        candidates = outputs[0]

        # Softmax is strictly monotonic, so re-normalising the scores before
        # taking the argmax cannot change the winner; compare scores directly.
        # `max` returns the first maximal entry on ties.
        best = max(candidates, key=lambda item: item["score"])
        return best["label"]