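# Web file-viewer header captured together with this file (not part of the program):
#   uploader avatar: ArneBinder
#   related change: https://github.com/ArneBinder/pie-document-level/pull/312
#   revision: 3133b5e (verified)
#   viewer links: raw, history, blame
#   file size: 6.13 kB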
import logging
import math
from typing import Any, Callable, Dict, List, Optional, Union
import numpy as np
import torch
from pandas import MultiIndex
from pie_modules.documents import TextPairDocumentWithLabeledSpansAndBinaryCorefRelations
from pytorch_ie import DocumentMetric
from pytorch_ie.core.metric import T
from pytorch_ie.utils.hydra import resolve_target
from torchmetrics import Metric, MetricCollection
from src.hydra_callbacks.save_job_return_value import to_py_obj
# Module-level logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
def get_num_total(targets: List[int], preds: List[float]):
    """Return the number of entries in ``targets``.

    ``preds`` is accepted so the function has the same ``(targets, preds)``
    call shape as the other metric callables in this module, but it is never read.
    """
    total = len(targets)
    return total
def get_num_positives(targets: List[int], preds: List[float], positive_idx: int = 1):
    """Return how many entries of ``targets`` are equal to ``positive_idx``.

    ``preds`` is accepted but never read; only ``targets`` is inspected.
    """
    return sum(1 for label in targets if label == positive_idx)
def discretize(
    values: List[float], threshold: Union[int, float, List[float], dict]
) -> Union[List[int], Dict[Any, List[int]]]:
    """Binarize ``values`` against one or several thresholds.

    Args:
        values: The scores to binarize.
        threshold: One of
            * a single number: each value becomes ``1`` if ``value >= threshold``
              and ``0`` otherwise;
            * a list of numbers: every entry is applied on its own;
            * a dict with the keys ``"start"``, ``"end"`` and ``"step"``: the
              thresholds are generated with ``np.arange(start, end, step)``
              (so ``end`` is exclusive) and rounded to 4 decimals.

    Returns:
        A list of 0/1 integers for a single threshold, otherwise a dict that
        maps each threshold to its list of 0/1 integers.

    Raises:
        TypeError: If ``threshold`` is none of the supported types.
        KeyError: If a dict ``threshold`` lacks ``"start"``, ``"end"`` or ``"step"``.
    """
    # Integers are valid thresholds too (e.g. ``1`` from a config file, or the
    # values produced by an integer ``np.arange`` sweep). ``bool`` is a subclass
    # of ``int`` but not a meaningful threshold, so it still ends in the TypeError.
    if isinstance(threshold, (int, float)) and not isinstance(threshold, bool):
        return (np.array(values) >= threshold).astype(int).tolist()
    if isinstance(threshold, list):
        return {t: discretize(values=values, threshold=t) for t in threshold}  # type: ignore
    if isinstance(threshold, dict):
        # Rounding removes floating point noise from the generated steps so that
        # the thresholds are usable as readable dict keys.
        thresholds = (
            np.arange(threshold["start"], threshold["end"], threshold["step"]).round(4).tolist()
        )
        return discretize(values, threshold=thresholds)
    raise TypeError(f"threshold has unknown type: {threshold}")
class CorefMetricsSKLearn(DocumentMetric):
    """Compute metrics over binary coreference relations of text pair documents.

    For every document, the gold and the predicted ``binary_coref_relations`` are
    aligned by their ``(head, tail)`` arguments. The collected targets and
    prediction scores are then passed to each configured metric callable as
    ``metric(targets, preds)``.

    Args:
        metrics: Maps a result name to an import path that is resolved with
            ``resolve_target`` into a callable.
        thresholds: Optional mapping from metric name to a threshold specification
            understood by ``discretize`` (a single number, a list of numbers, or a
            ``start``/``end``/``step`` dict). Metrics without an entry receive the
            raw prediction scores. If several thresholds are evaluated for a
            metric, only the best result is reported, under the key
            ``"<name>-<threshold>"``.
        default_target_idx: Target used for argument pairs that only occur in
            the predictions.
        default_prediction_score: Score used for argument pairs that only occur
            in the gold relations.
        show_as_markdown: If True, log the result as a markdown table.
        markdown_precision: Number of decimals used for the markdown table.
        plot: If True, ``do_plot`` is called at the start of the computation
            (which currently raises ``NotImplementedError``).
    """

    DOCUMENT_TYPE = TextPairDocumentWithLabeledSpansAndBinaryCorefRelations

    def __init__(
        self,
        metrics: Dict[str, str],
        thresholds: Optional[Dict[str, Union[float, List[float], dict]]] = None,
        default_target_idx: int = 0,
        default_prediction_score: float = 0.0,
        show_as_markdown: bool = False,
        markdown_precision: int = 4,
        plot: bool = False,
    ):
        self.metrics = {name: resolve_target(metric) for name, metric in metrics.items()}
        self.thresholds = thresholds or {}
        # A threshold without a matching metric is never used; warn because this
        # most likely is a typo in the configuration.
        thresholds_not_in_metrics = {
            name: t for name, t in self.thresholds.items() if name not in self.metrics
        }
        if len(thresholds_not_in_metrics) > 0:
            logger.warning(
                f"there are discretizing thresholds that do not have a metric: {thresholds_not_in_metrics}"
            )
        self.default_target_idx = default_target_idx
        self.default_prediction_score = default_prediction_score
        self.show_as_markdown = show_as_markdown
        self.markdown_precision = markdown_precision
        self.plot = plot
        # NOTE(review): presumably the base class constructor calls reset(), which
        # creates the _preds / _targets buffers - confirm against DocumentMetric.
        super().__init__()

    def reset(self) -> None:
        """Drop all collected prediction scores and targets."""
        self._preds: List[float] = []
        self._targets: List[int] = []

    def _update(self, document: TextPairDocumentWithLabeledSpansAndBinaryCorefRelations) -> None:
        """Collect aligned targets and prediction scores from a single document."""
        target_args2idx = {
            (rel.head, rel.tail): int(rel.score) for rel in document.binary_coref_relations
        }
        prediction_args2score = {
            (rel.head, rel.tail): rel.score for rel in document.binary_coref_relations.predictions
        }
        # Use the union of the argument pairs so that relations occurring only on
        # one side are still counted, with the configured default on the other side.
        all_args = set(target_args2idx) | set(prediction_args2score)
        all_targets: List[int] = []
        all_predictions: List[float] = []
        for args in all_args:
            all_targets.append(target_args2idx.get(args, self.default_target_idx))
            all_predictions.append(
                prediction_args2score.get(args, self.default_prediction_score)
            )
        self._preds.extend(all_predictions)
        self._targets.extend(all_targets)

    def do_plot(self):
        """Plot the metrics (not implemented).

        The metrics are kept in a plain dict of callables, which offers no
        ``plot`` method, so there is currently nothing that could be plotted.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("plotting is not implemented for CorefMetricsSKLearn")

    def _compute(self) -> T:
        """Evaluate all metrics on the collected targets and predictions.

        Returns:
            The metric results converted with ``to_py_obj``, keyed by metric name
            (or ``"<name>-<threshold>"`` for the best entry of a threshold sweep).

        Raises:
            NotImplementedError: If ``plot`` is enabled.
            ValueError: If a threshold sweep for a metric contains no thresholds.
        """
        if self.plot:
            self.do_plot()

        result = {}
        for name, metric in self.metrics.items():
            if name in self.thresholds:
                preds = discretize(values=self._preds, threshold=self.thresholds[name])
            else:
                preds = self._preds
            if isinstance(preds, dict):
                # Threshold sweep: evaluate the metric once per threshold.
                metric_results = {
                    t: metric(self._targets, t_preds) for t, t_preds in preds.items()
                }
                if not metric_results:
                    # max() would fail with a generic message; name the metric instead.
                    raise ValueError(f"no thresholds to evaluate for metric '{name}'")
                # Report only the best value together with the threshold that produced it.
                max_t, max_v = max(metric_results.items(), key=lambda k_v: k_v[1])
                result[f"{name}-{max_t}"] = max_v
            else:
                result[name] = metric(self._targets, preds)

        result = to_py_obj(result)
        if self.show_as_markdown:
            import pandas as pd

            series = pd.Series(result)
            if isinstance(series.index, MultiIndex):
                if len(series.index.levels) > 1:
                    # in fact, this is not a series anymore
                    series = series.unstack(-1)
                else:
                    series.index = series.index.get_level_values(0)
            logger.info(
                f"{self.current_split}\n{series.round(self.markdown_precision).to_markdown()}"
            )
        return result