Upload NLIScorer
Browse files- pipeline.py +9 -14
pipeline.py
CHANGED
@@ -363,12 +363,13 @@ class NLIScorer(Pipeline):
|
|
363 |
def _sanitize_parameters(self, **kwargs):
|
364 |
preprocess_kwargs = {}
|
365 |
postprocess_kwargs = {}
|
366 |
-
|
367 |
-
|
|
|
|
|
368 |
return preprocess_kwargs, {}, postprocess_kwargs
|
369 |
|
370 |
-
def preprocess(self, inputs):
|
371 |
-
task_name = inputs.pop("task_type")
|
372 |
TaskClass = TASK_CLASSES[task_name]
|
373 |
task_class = TaskClass(tokenizer=self.tokenizer, **inputs)
|
374 |
return task_class.as_model_inputs
|
@@ -377,17 +378,11 @@ class NLIScorer(Pipeline):
|
|
377 |
outputs = self.model(**model_inputs)
|
378 |
return outputs
|
379 |
|
380 |
-
def postprocess(self, model_outputs):
|
381 |
pos_scores = model_outputs["logits"].softmax(-1)[0][1]
|
382 |
-
|
383 |
-
|
384 |
-
def __call__(self, inputs, **kwargs):
|
385 |
-
task_name = inputs.get("task_type")
|
386 |
-
task_threshold = TASK_THRESHOLDS[task_name]
|
387 |
-
outputs = super().__call__(inputs, **kwargs)
|
388 |
-
best_class = int(outputs["score"] > task_threshold)
|
389 |
if best_class == 1:
|
390 |
-
score =
|
391 |
else:
|
392 |
-
score = 1 -
|
393 |
return {"score": score.item(), "label": best_class}
|
|
|
363 |
def _sanitize_parameters(self, **kwargs):
|
364 |
preprocess_kwargs = {}
|
365 |
postprocess_kwargs = {}
|
366 |
+
if "task_name" in kwargs:
|
367 |
+
postprocess_kwargs["task_name"] = kwargs["task_name"]
|
368 |
+
if "threshold" in kwargs:
|
369 |
+
postprocess_kwargs["threshold"] = kwargs["threshold"]
|
370 |
return preprocess_kwargs, {}, postprocess_kwargs
|
371 |
|
372 |
+
def preprocess(self, inputs, task_name):
|
|
|
373 |
TaskClass = TASK_CLASSES[task_name]
|
374 |
task_class = TaskClass(tokenizer=self.tokenizer, **inputs)
|
375 |
return task_class.as_model_inputs
|
|
|
378 |
outputs = self.model(**model_inputs)
|
379 |
return outputs
|
380 |
|
381 |
+
def postprocess(self, model_outputs, threshold=0.5):
|
382 |
pos_scores = model_outputs["logits"].softmax(-1)[0][1]
|
383 |
+
best_class = int(pos_scores > threshold)
|
|
|
|
|
|
|
|
|
|
|
|
|
384 |
if best_class == 1:
|
385 |
+
score = pos_scores
|
386 |
else:
|
387 |
+
score = 1 - pos_scores
|
388 |
return {"score": score.item(), "label": best_class}
|