Update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
commit d868d2e (verified)
```python
import logging
import math
from typing import Any, Callable, Dict, List, Optional, Union, overload

import numpy as np
from pandas import MultiIndex
from pie_modules.utils import flatten_dict
from pytorch_ie import Document, DocumentMetric
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target

from src.hydra_callbacks.save_job_return_value import to_py_obj

logger = logging.getLogger(__name__)

def get_num_total(targets: List[int], preds: List[float]):
    return len(targets)


def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1):
    return len([v for v in targets if v == positive_idx])

@overload
def discretize(values: List[float], threshold: float) -> List[float]: ...


@overload
def discretize(values: List[float], threshold: List[float]) -> Dict[Any, List[float]]: ...


def discretize(
    values: List[float], threshold: Union[float, List[float], dict]
) -> Union[List[float], Dict[Any, List[float]]]:
    if isinstance(threshold, float):
        result = (np.array(values) >= threshold).astype(int).tolist()
        return result
    if isinstance(threshold, list):
        return {t: discretize(values=values, threshold=t) for t in threshold}  # type: ignore
    if isinstance(threshold, dict):
        thresholds = (
            np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist()
        )
        return discretize(values, threshold=thresholds)
    raise TypeError(f"threshold has unknown type: {threshold}")

def get_metric_func(name: str) -> Callable:
    if name.endswith("_curve"):
        from sklearn.metrics import auc

        base_func = resolve_target(name)

        def wrapper(targets: List[int], preds: List[float], **kwargs):
            # report the area under the curve, using the second return value as the
            # x-axis (e.g. recall for precision_recall_curve)
            x, y, thresholds = base_func(targets, preds, **kwargs)
            return auc(y, x)

        return wrapper
    else:
        return resolve_target(name)

def bootstrap(
    metric_fn: Callable[[List[int], Union[List[int], List[float]]], float],
    targets: List[int],
    predictions: Union[List[int], List[float]],
    n: int = 1_000,
    random_state: Optional[int] = None,
    alpha: float = 0.95,
) -> Dict[str, float]:
    """
    Returns the mean and a two-sided (1 - alpha) bootstrap CI for any
    pair-wise classification or ranking metric.

    Parameters
    ----------
    metric_fn     Metric function taking (targets, predictions) lists.
    targets       Ground-truth 0/1 labels.
    predictions   Scores or hard predictions (same length as `targets`).
    n             Number of bootstrap replicates (after skipping degenerate ones).
    random_state  Seed for reproducibility.
    alpha         Confidence level (default 0.95 -> 95% CI).

    Notes
    -----
    * A replicate that contains only one class is discarded
      because many sklearn metrics are undefined in that case.
    * If no valid replicate can be drawn, an exception is raised.
    """
    y = np.asarray(targets)
    yhat = np.asarray(predictions)
    if y.shape[0] != yhat.shape[0]:
        raise ValueError("`targets` and `predictions` must have the same length")

    rng = np.random.default_rng(random_state)
    idx = np.arange(y.shape[0])
    vals_list: List[float] = []
    attempts = 0
    max_attempts = 10 * n  # avoid looping forever if no valid replicate can be drawn
    while len(vals_list) < n and attempts < max_attempts:
        attempts += 1
        sample_idx = rng.choice(idx, size=idx.shape[0], replace=True)
        y_samp, yhat_samp = y[sample_idx], yhat[sample_idx]
        # skip all-positive or all-negative bootstrap samples
        if y_samp.min() == y_samp.max():
            continue
        vals_list.append(metric_fn(y_samp.tolist(), yhat_samp.tolist()))

    if not vals_list:
        raise RuntimeError("No valid bootstrap replicate contained both classes.")

    vals = np.asarray(vals_list, dtype=float)
    lower = np.percentile(vals, (1 - alpha) / 2 * 100)
    upper = np.percentile(vals, (1 + alpha) / 2 * 100)
    return {"mean": float(vals.mean()), "low": float(lower), "high": float(upper)}

class BinaryClassificationMetricsSKLearn(DocumentMetric):
    def __init__(
        self,
        metrics: Dict[str, str],
        layer: str,
        label: Optional[str] = None,
        thresholds: Optional[Dict[str, Union[float, List[float], dict]]] = None,
        default_target_idx: int = 0,
        default_prediction_score: float = 0.0,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        bootstrap: Optional[List[str]] = None,
        bootstrap_n: int = 1_000,
        bootstrap_random_state: Optional[int] = None,
        bootstrap_alpha: float = 0.95,
        create_plots: bool = True,
        plots: Optional[Dict[str, str]] = None,
    ):
        self.metrics = {name: get_metric_func(metric) for name, metric in metrics.items()}
        self.thresholds = thresholds or {}
        thresholds_not_in_metrics = {
            name: t for name, t in self.thresholds.items() if name not in self.metrics
        }
        if len(thresholds_not_in_metrics) > 0:
            logger.warning(
                f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
            )
        self.annotation_layer_name = layer
        self.annotation_label = label
        self.default_target_idx = default_target_idx
        self.default_prediction_score = default_prediction_score
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        if create_plots:
            self.plots = {
                name: resolve_target(plot_func) for name, plot_func in (plots or {}).items()
            }
        else:
            self.plots = {}
        self.bootstrap = set(bootstrap or [])
        self.bootstrap_kwargs = {
            "n": bootstrap_n,
            "random_state": bootstrap_random_state,
            "alpha": bootstrap_alpha,
        }
        super().__init__()

    def reset(self) -> None:
        self._preds: List[float] = []
        self._targets: List[int] = []

    def _update(self, document: Document) -> None:
        annotation_layer = document[self.annotation_layer_name]
        target2idx = {
            ann: int(ann.score)
            for ann in annotation_layer
            if self.annotation_label is None or ann.label == self.annotation_label
        }
        prediction2score = {
            ann: ann.score
            for ann in annotation_layer.predictions
            if self.annotation_label is None or ann.label == self.annotation_label
        }
        all_args = set(target2idx) | set(prediction2score)
        all_targets: List[int] = []
        all_predictions: List[float] = []
        for args in all_args:
            target_idx = target2idx.get(args, self.default_target_idx)
            prediction_score = prediction2score.get(args, self.default_prediction_score)
            all_targets.append(target_idx)
            all_predictions.append(prediction_score)
        self._preds.extend(all_predictions)
        self._targets.extend(all_targets)

    def create_plots(self):
        from matplotlib import pyplot as plt

        # Get the number of plots to create
        num_plots = len(self.plots)
        # Calculate rows and columns for subplots (aim for a square-like layout)
        ncols = math.ceil(math.sqrt(num_plots))
        nrows = math.ceil(num_plots / ncols)
        # Create the subplots
        fig, ax_list = plt.subplots(nrows=nrows, ncols=ncols, figsize=(15, 10))
        # Ensure ax_list is a flat list (plt.subplots returns a single Axes for one plot)
        if num_plots > 1:
            ax_list = ax_list.flatten().tolist()
        else:
            ax_list = [ax_list]

        # Create each plot
        for ax, (name, plot_func) in zip(ax_list, self.plots.items()):
            # Set the title for each subplot
            ax.set_title(name)
            plot_func(y_true=self._targets, y_pred=self._preds, ax=ax)

        # Adjust layout to avoid overlapping plots
        plt.tight_layout()
        plt.show()

    def _compute(self) -> T:
        if len(self.plots) > 0:
            self.create_plots()
        result = {}
        for name, metric in self.metrics.items():
            if name in self.thresholds:
                preds_dict = discretize(values=self._preds, threshold=self.thresholds[name])
                if isinstance(preds_dict, dict):
                    metric_results = {
                        t: metric(self._targets, t_preds) for t, t_preds in preds_dict.items()
                    }
                    # keep the threshold that maximizes the metric
                    max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
                    result[f"{name}_threshold"] = max_t
                    preds = discretize(values=self._preds, threshold=max_t)
                else:
                    preds = preds_dict
            else:
                preds = self._preds
            if name in self.bootstrap:
                # bootstrap the metric to get a confidence interval instead of a point estimate
                result[name] = bootstrap(
                    metric_fn=metric,
                    targets=self._targets,
                    predictions=preds,
                    **self.bootstrap_kwargs,  # type: ignore
                )
            else:
                result[name] = metric(self._targets, preds)
        result = to_py_obj(result)

        if self.show_as_markdown:
            import pandas as pd

            result_flat = flatten_dict(result)
            series = pd.Series(result_flat)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result
```
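
For reference, a minimal usage sketch (not part of the commit): the toy labels and scores, the metric-name keys (`auroc`, `f1`, `pr_auc`), and the annotation layer name `binary_relations` are made-up placeholders; the dotted metric paths are standard sklearn callables that `resolve_target` is used to resolve, as in the module above.

```python
from sklearn.metrics import f1_score

# Toy data (made up): 0/1 targets and prediction scores.
targets = [0, 0, 1, 1, 0, 1, 0, 1]
scores = [0.1, 0.4, 0.35, 0.8, 0.2, 0.65, 0.55, 0.9]

# Sweep thresholds and keep the best F1, mirroring what `_compute` does for
# dict-style thresholds.
preds_per_threshold = discretize(
    values=scores, threshold={"start": 0.1, "end": 0.9, "step": 0.1}
)
best_t, best_f1 = max(
    ((t, f1_score(targets, preds)) for t, preds in preds_per_threshold.items()),
    key=lambda t_v: t_v[1],
)

# Bootstrap a 95% CI for F1 at the selected threshold.
ci = bootstrap(
    metric_fn=f1_score,
    targets=targets,
    predictions=discretize(values=scores, threshold=best_t),
    n=200,
    random_state=42,
)
print(best_t, best_f1, ci)  # ci = {"mean": ..., "low": ..., "high": ...}

# Configuring the DocumentMetric itself; "binary_relations" is a hypothetical
# annotation layer name.
metric = BinaryClassificationMetricsSKLearn(
    layer="binary_relations",
    metrics={
        "auroc": "sklearn.metrics.roc_auc_score",
        "f1": "sklearn.metrics.f1_score",
        # "*_curve" entries are wrapped so that the area under the curve is reported
        "pr_auc": "sklearn.metrics.precision_recall_curve",
    },
    thresholds={"f1": {"start": 0.1, "end": 0.9, "step": 0.1}},
    bootstrap=["auroc"],
    create_plots=False,
)
```

In practice these arguments would presumably be supplied via a Hydra config, since the metric and plot entries are plain dotted paths resolved with `resolve_target`.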