Spaces:
Sleeping
Sleeping
from textattack.models.wrappers import HuggingFaceModelWrapper | |
class TADModelWrapper(HuggingFaceModelWrapper):
    """Adapt a model exposing an ``infer`` method to TextAttack's wrapper API.

    The wrapper is called with a batch of input texts and must return one
    score vector per input in a form TextAttack understands, e.g.
    ``[[0.218262017, 0.7817379832267761]]`` for a single two-class input.
    This wrapper runs ``model.infer`` on each text and collects the
    ``"probs"`` entry of every result.
    """

    def __init__(self, model):
        """Store the wrapped model.

        Args:
            model: Object providing ``infer(text, print_result=..., **kwargs)``
                whose return value supports ``["probs"]`` lookup.

        NOTE(review): ``HuggingFaceModelWrapper.__init__`` is not called, so
        any attributes it would set (such as a tokenizer) are absent.
        Presumably deliberate because ``model`` is not a HuggingFace model;
        confirm before relying on inherited methods that need them.
        """
        self.model = model

    def __call__(self, text_inputs, **kwargs):
        """Return the ``"probs"`` output of ``model.infer`` for each input.

        Args:
            text_inputs: Iterable of input texts; each is queried separately.
            **kwargs: Extra keyword arguments forwarded to every ``infer`` call.

        Returns:
            A list with one ``"probs"`` value per input, in input order.
        """
        # print_result=False presumably stops infer from printing each
        # prediction (inferred from the flag name; confirm against the model).
        return [
            self.model.infer(text_input, print_result=False, **kwargs)["probs"]
            for text_input in text_inputs
        ]