llm-studio / llm_studio /src /metrics /text_causal_regression_modeling_metrics.py
qinfeng722's picture
Upload 322 files
5caedb4 verified
raw
history blame
2.32 kB
import logging
from typing import Any, Dict, List, Tuple, Union
import numpy as np
import pandas as pd
from numpy.typing import NDArray
logger = logging.getLogger(__name__)
def mse_score(
cfg: Any,
results: Dict,
val_df: pd.DataFrame,
raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
target = np.array(
[[float(t) for t in text.split(",")] for text in results["target_text"]]
)
predictions = np.array(results["predictions"])
if len(target) != len(predictions):
raise ValueError(
f"Length of target ({len(target)}) and predictions ({len(predictions)}) "
"should be the same."
)
if len(target) == 0:
raise ValueError("No data to calculate MSE score")
return ((target - predictions) ** 2).mean(axis=1).reshape(-1).astype("float")
def mae_score(
cfg: Any,
results: Dict,
val_df: pd.DataFrame,
raw_results: bool = False,
) -> Union[NDArray, Tuple[NDArray, List[str]]]:
target = np.array(
[[float(t) for t in text.split(",")] for text in results["target_text"]]
)
predictions = np.array(results["predictions"])
if len(target) != len(predictions):
raise ValueError(
f"Length of target ({len(target)}) and predictions ({len(predictions)}) "
"should be the same."
)
if len(target) == 0:
raise ValueError("No data to calculate MAE score")
return np.abs(target - predictions).mean(axis=1).reshape(-1).astype("float")
class Metrics:
"""
Metrics factory. Returns:
- metric value
- should it be maximized or minimized
- Reduce function
Maximized or minimized is needed for early stopping (saving best checkpoint)
Reduce function to generate a single metric value, usually "mean" or "none"
"""
_metrics = {
"MSE": (mse_score, "min", "mean"),
"MAE": (mae_score, "min", "mean"),
}
@classmethod
def names(cls) -> List[str]:
return sorted(cls._metrics.keys())
@classmethod
def get(cls, name: str) -> Any:
"""Access to Metrics.
Args:
name: metrics name
Returns:
A class to build the Metrics
"""
return cls._metrics.get(name, cls._metrics["MSE"])