|
import sys |
|
import logging |
|
from typing import Dict, Any, Sequence |
|
|
|
from transformers import EvalPrediction |
|
|
|
from ...utils import decode_generate_ids |
|
|
|
# Route log records to stdout with a timestamped format. ``basicConfig``
# configures the root logger and is a no-op when the root logger already
# has handlers installed.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Module-level logger; records at INFO and above are emitted.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
|
|
|
|
|
class BaseComputeMetrics:
    """Base class for exact-match accuracy over decoded predictions.

    Subclasses implement :meth:`extract_ans`, which pulls the answer out of
    a decoded string. An instance is callable with an ``EvalPrediction`` and
    returns a metrics dictionary.
    """

    def __init__(self, preprocessor: Dict[str, Any]):
        """Store the preprocessor and pick its ``'text'`` entry as tokenizer.

        Args:
            preprocessor: Mapping that must contain a ``'text'`` key; that
                value is what gets passed to ``decode_generate_ids``.

        Raises:
            KeyError: If ``preprocessor`` has no ``'text'`` entry.
        """
        self.preprocessor = preprocessor
        self.tokenizer = self.preprocessor['text']

    def __call__(self, eval_preds: EvalPrediction) -> Dict[str, Any]:
        """Decode predictions and targets, then compute the metrics.

        Args:
            eval_preds: Unpacked into ``(preds, targets)``; both are expected
                to expose ``.shape`` (it is logged before decoding).

        Returns:
            The dictionary produced by :meth:`calculate_metric`.
        """
        preds, targets = eval_preds
        logger.warning(
            "preds shape: %s. targets shape: %s", preds.shape, targets.shape
        )
        # NOTE(review): decode_generate_ids presumably turns token ids into
        # strings using the tokenizer -- confirm against its definition.
        preds = decode_generate_ids(self.tokenizer, preds)
        targets = decode_generate_ids(self.tokenizer, targets)
        assert len(preds) == len(targets), (
            f"decoded {len(preds)} predictions but {len(targets)} targets"
        )
        return self.calculate_metric(preds, targets)

    def calculate_metric(self, preds: Sequence[str], targets: Sequence[str]) -> Dict[str, Any]:
        """Compute exact-match accuracy between extracted answers.

        Args:
            preds: Decoded prediction strings.
            targets: Decoded target strings, paired positionally with
                ``preds``.

        Returns:
            A dict with:
              * ``'accuracy'``: matches divided by ``len(targets)``, or
                ``0.0`` when ``targets`` is empty.
              * ``'target_failed'``: targets whose answer could not be
                extracted (these pairs are skipped but still count in the
                accuracy denominator).
              * ``'failed'``: predictions whose answer could not be
                extracted while the target's could.
        """
        correct = 0
        failed = 0
        target_failed = 0
        for pred, target in zip(preds, targets):
            extract_pred = self.extract_ans(pred)
            extract_target = self.extract_ans(target)
            if extract_target is None:
                target_failed += 1
                logger.warning(
                    "failed to extract ans from target. "
                    "maybe the response string is truncated: %s.",
                    target,
                )
                continue
            if extract_pred is None:
                # A missing prediction can never match a real target.
                failed += 1
            elif extract_pred == extract_target:
                correct += 1

        total = len(targets)
        # Guard against an empty evaluation set instead of raising
        # ZeroDivisionError.
        accuracy = correct / total if total else 0.0
        return {
            'accuracy': accuracy,
            'target_failed': target_failed,
            'failed': failed,
        }

    def extract_ans(self, string: str):
        """Extract the answer from a decoded string.

        Subclasses must override this. ``calculate_metric`` treats a return
        value of ``None`` as "extraction failed".

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|
|