File size: 1,880 Bytes
3e1d9f3 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 |
import sys
import logging
from typing import Dict, Any, Sequence
from transformers import EvalPrediction
from ...utils import decode_generate_ids
# Module-level logger; records below INFO are dropped for this module.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)

# Root-logger configuration: timestamped records written to stdout.
# NOTE: logging.basicConfig is a no-op when the root logger already has
# handlers, so an earlier configuration elsewhere takes precedence.
_LOG_FORMAT = "%(asctime)s - %(levelname)s - %(name)s - %(message)s"
_LOG_DATE_FORMAT = "%m/%d/%Y %H:%M:%S"
_stdout_handler = logging.StreamHandler(sys.stdout)
logging.basicConfig(
    format=_LOG_FORMAT,
    datefmt=_LOG_DATE_FORMAT,
    handlers=[_stdout_handler],
)
class BaseComputeMetrics:
    """Base class for exact-match accuracy over decoded predictions.

    An instance is callable with a ``(preds, targets)`` pair. Both members
    are decoded with ``decode_generate_ids`` using the tokenizer stored under
    the ``'text'`` key of the preprocessor, and the decoded strings are scored
    by :meth:`calculate_metric`.

    Subclasses must implement :meth:`extract_ans`.
    """

    def __init__(self, preprocessor: Dict[str, Any]):
        """Keep the preprocessor and use its ``'text'`` entry as the tokenizer.

        Args:
            preprocessor: Mapping that must contain a ``'text'`` key.

        Raises:
            KeyError: If ``preprocessor`` has no ``'text'`` entry.
        """
        self.preprocessor = preprocessor
        self.tokenizer = self.preprocessor['text']

    def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, Any]:
        """Decode predictions and targets, then compute the metrics.

        Args:
            eval_preds: Object that unpacks into ``(preds, targets)``; both
                must expose a ``.shape`` attribute (it is logged).

        Returns:
            The dictionary produced by :meth:`calculate_metric`.
        """
        preds, targets = eval_preds
        logger.warning("preds shape: %s. targets shape: %s", preds.shape, targets.shape)
        # Presumably turns token-id arrays into one string per sample;
        # the helper lives in another module -- TODO confirm.
        preds = decode_generate_ids(self.tokenizer, preds)
        targets = decode_generate_ids(self.tokenizer, targets)
        assert len(preds) == len(targets), (
            f"decoded preds/targets length mismatch: {len(preds)} != {len(targets)}"
        )
        return self.calculate_metric(preds, targets)

    def calculate_metric(self, preds: Sequence[str], targets: Sequence[str]) -> Dict[str, Any]:
        """Score predictions against targets by comparing extracted answers.

        Each prediction/target pair is passed through :meth:`extract_ans`.
        A pair whose target yields ``None`` is counted in ``target_failed``
        and skipped; a pair whose prediction yields ``None`` is counted in
        ``failed``. Skipped and failed pairs still count in the accuracy
        denominator, which is ``len(targets)``.

        Args:
            preds: Decoded prediction strings.
            targets: Decoded reference strings.

        Returns:
            A dict with ``'accuracy'`` (float), ``'target_failed'`` (int) and
            ``'failed'`` (int). Accuracy is ``0.0`` when ``targets`` is empty.
        """
        correct = 0
        failed = 0
        target_failed = 0
        for pred, target in zip(preds, targets):
            extract_pred = self.extract_ans(pred)
            extract_target = self.extract_ans(target)
            if extract_target is None:
                target_failed += 1
                logger.warning(
                    "failed to extract ans from target. maybe the response string is truncated: %s.",
                    target,
                )
                continue
            if extract_pred is None:
                failed += 1
            elif extract_pred == extract_target:
                correct += 1

        total = len(targets)
        # Guard against an empty evaluation set instead of dividing by zero.
        accuracy = 1.0 * correct / total if total else 0.0
        return {
            'accuracy': accuracy,
            'target_failed': target_failed,
            'failed': failed,
        }

    def extract_ans(self, string: str):
        """Extract the comparable answer from a decoded string.

        Subclasses must override this. Returning ``None`` signals that no
        answer could be extracted; :meth:`calculate_metric` treats that as a
        failure for the corresponding prediction or target.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError
|