Spaces:
Sleeping
Sleeping
| """ | |
| A script for a text sentiment analysis tool for the 🤗 Transformers Agent library. | |
| """ | |
| from transformers import Tool | |
| from transformers.tools.base import get_default_device | |
| from transformers import pipeline | |
| from transformers import DistilBertTokenizerFast | |
| from trainDistilBERT import DistilBertForMulticlassSequenceClassification | |
| import torch | |
class SentAnalClassifierTool(Tool):
    """
    A tool for sentiment analysis.

    Loads the ``ckpt`` checkpoint into a ``sentiment-analysis`` pipeline the
    first time the tool is called and returns the label with the highest
    score for the given text.
    """

    ckpt = "ongknsro/ACARISBERT-DistilBERT"
    name = "text_sentiment_analyzer"
    description = (
        "This is a tool that returns a sentiment label for a given text sequence. "
        "It takes the raw text as input, and "
        "returns a sentiment label as output."
    )
    inputs = ["text"]
    outputs = ["text"]

    def __init__(self, device=None, **hub_kwargs) -> None:
        """
        Create the tool without loading any model weights.

        Args:
            device: Device to run the model on. When ``None``, the default
                device is resolved in :meth:`setup`.
            **hub_kwargs: Extra keyword arguments; stored on the instance as
                ``self.hub_kwargs`` (not consumed by this class itself).
        """
        super().__init__()
        self.device = device
        self.pipeline = None
        self.hub_kwargs = hub_kwargs

    def setup(self) -> None:
        """
        Load the tokenizer and model and build the inference pipeline.

        Resolves ``self.device`` if it was not given, then marks the tool as
        initialized so later calls skip this step.
        """
        if self.device is None:
            self.device = get_default_device()

        self.tokenizer = DistilBertTokenizerFast.from_pretrained(self.ckpt)
        self.model = DistilBertForMulticlassSequenceClassification.from_pretrained(self.ckpt).to(self.device)
        # Run the pipeline on the same device the model was moved to instead
        # of a hard-coded CUDA index, so CPU-only hosts and an explicit
        # ``device`` argument both work.
        self.pipeline = pipeline(
            "sentiment-analysis",
            model=self.model,
            tokenizer=self.tokenizer,
            top_k=None,
            device=self.device,
        )
        self.is_initialized = True

    def __call__(self, task: str) -> str:
        """
        Classify the sentiment of a piece of text.

        Args:
            task: The raw text to classify.

        Returns:
            The label whose score is highest among the pipeline's results
            for the input.
        """
        if not self.is_initialized:
            self.setup()

        outputs = self.pipeline(task)
        # Softmax is strictly monotonic, so re-normalising the scores cannot
        # change which one is largest; pick the top-scoring entry directly.
        best = max(outputs[0], key=lambda item: item["score"])
        return best["label"]