RxnIM / mllm /dataset /utils /compute_metrics.py
CYF200127's picture
Upload 235 files
3e1d9f3 verified
raw
history blame
1.88 kB
import sys
import logging
from typing import Dict, Any, Sequence
from transformers import EvalPrediction
from ...utils import decode_generate_ids
# Send log records to stdout with a timestamped "time - level - name - message"
# layout. ``basicConfig`` leaves the root logger alone if it already has handlers.
logging.basicConfig(
    handlers=[logging.StreamHandler(sys.stdout)],
    datefmt="%m/%d/%Y %H:%M:%S",
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
)

# Module-level logger; its own threshold is pinned to INFO.
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
class BaseComputeMetrics:
    """Base class computing exact-match accuracy over decoded generations.

    Instances are callable on an ``EvalPrediction``. Subclasses must override
    :meth:`extract_ans`, which pulls the comparable answer out of a decoded
    string (or returns ``None`` when no answer can be extracted).
    """

    def __init__(self, preprocessor: Dict[str, Any]):
        """Keep the preprocessor mapping and use its ``'text'`` entry as tokenizer.

        Args:
            preprocessor: mapping that must contain a ``'text'`` key; the value
                stored there is later handed to ``decode_generate_ids``.

        Raises:
            KeyError: if ``preprocessor`` has no ``'text'`` entry.
        """
        self.preprocessor = preprocessor
        self.tokenizer = self.preprocessor['text']

    def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, Any]:
        """Decode predictions and targets, then score them.

        Args:
            eval_preds: unpacked as a ``(preds, targets)`` pair. Both items
                must expose ``.shape`` (it is logged) and are decoded with
                ``decode_generate_ids`` using the stored tokenizer.

        Returns:
            The metrics dict produced by :meth:`calculate_metric`.

        Raises:
            AssertionError: if decoding yields a different number of
                predictions and targets.
        """
        preds, targets = eval_preds
        logger.warning(f"preds shape: {preds.shape}. targets shape: {targets.shape}")
        preds = decode_generate_ids(self.tokenizer, preds)
        targets = decode_generate_ids(self.tokenizer, targets)
        assert len(preds) == len(targets), (
            f"number of decoded preds ({len(preds)}) != number of decoded targets ({len(targets)})"
        )
        return self.calculate_metric(preds, targets)

    def calculate_metric(self, preds: Sequence[str], targets: Sequence[str]) -> Dict[str, Any]:
        """Compute exact-match accuracy between extracted answers.

        Args:
            preds: decoded prediction strings.
            targets: decoded target strings, paired positionally with
                ``preds`` (pairing stops at the shorter sequence).

        Returns:
            A dict with:

            * ``'accuracy'``: matches divided by ``len(targets)``; ``0.0`` when
              ``targets`` is empty. Samples whose target could not be parsed
              stay in the denominator but are never counted as correct.
            * ``'target_failed'``: number of targets for which
              :meth:`extract_ans` returned ``None``.
            * ``'failed'``: number of predictions for which
              :meth:`extract_ans` returned ``None`` (only counted when the
              matching target was parsed successfully).
        """
        total = len(targets)
        correct = 0
        failed = 0
        target_failed = 0
        for pred, target in zip(preds, targets):
            extract_pred = self.extract_ans(pred)
            extract_target = self.extract_ans(target)
            if extract_target is None:
                target_failed += 1
                logger.warning(f"failed to extract ans from target. maybe the response string is truncated: {target}.")
                continue
            if extract_pred is None:
                # An unparseable prediction can never match a parsed target.
                failed += 1
            elif extract_pred == extract_target:
                correct += 1
        return {
            # Guard against an empty evaluation set instead of dividing by zero.
            'accuracy': correct / total if total else 0.0,
            'target_failed': target_failed,
            'failed': failed,
        }

    def extract_ans(self, string: str):
        """Extract the answer from a decoded string.

        Subclasses must override this. :meth:`calculate_metric` treats a
        ``None`` return value as "extraction failed".

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError