# Copyright 2020 The HuggingFace Datasets Authors and the current dataset script contributor.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import datetime
import os
from tqdm import tqdm
import datasets
import evaluate
from seametrics.user_friendly.utils import payload_to_uf_metrics, UFM
from seametrics.payload import Payload
import wandb
_CITATION = """\
@InProceedings{huggingface:module,
title = {A great new module},
authors={huggingface, Inc.},
year={2020}
}
@article{milan2016mot16,
title={MOT16: A benchmark for multi-object tracking},
author={Milan, Anton and Leal-Taix{\'e}, Laura and Reid, Ian and Roth, Stefan and Schindler, Konrad},
journal={arXiv preprint arXiv:1603.00831},
year={2016}
}
"""
_DESCRIPTION = """\
The User-Friendly Metrics module evaluates multi-object tracking (MOT)
algorithms by computing various metrics based on predicted and ground-truth
bounding boxes. It serves as a tool for assessing the performance of MOT
systems and aiding in the iterative improvement of tracking algorithms."""
_KWARGS_DESCRIPTION = """
Calculates how well predictions match the given references, using tracking-based scores.
Args:
    predictions: list with one entry per sequence. Each entry contains the
        predicted detections for that sequence as sequences of floats.
    references: list with one entry per sequence. Each entry maps a filter-range
        name (e.g. "all") to that sequence's ground-truth detections as sequences
        of floats.
    iou_threshold (`float`, *optional*):
        Minimum Intersection over Union (IoU) threshold for a detection to count
        as a true positive. Set on the module constructor; defaults to 1e-10.
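Example (illustrative; `my_payload` and "my_model" are placeholders for a
seametrics Payload and a model name defined elsewhere):
    >>> metric = UserFriendlyMetrics()
    >>> results = metric.compute_from_payload(my_payload)
    >>> results["my_model"]["overall"]["all"]  # derived scores over all sequences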
"""
@evaluate.utils.file_utils.add_start_docstrings(_DESCRIPTION, _KWARGS_DESCRIPTION)
class UserFriendlyMetrics(evaluate.Metric):
"""TODO: Short description of my evaluation module."""
def __init__(
self,
iou_threshold: float = 1e-10,
recognition_thresholds=[0.3, 0.5, 0.8],
filter_dict={"name": "area", "ranges": [("all", [0, 1e5**2])]},
**kwargs):
super().__init__(**kwargs)
# save parameters for later
self.iou_threshold = iou_threshold
self.filter_dict = filter_dict
self.recognition_thresholds = recognition_thresholds
self.metric = UFM(iou_threshold, recognition_thresholds)
def _info(self):
        # Returns the evaluate.MetricInfo object that describes this module
return evaluate.MetricInfo(
# This is the description that will appear on the modules page.
module_type="metric",
description=_DESCRIPTION,
citation=_CITATION,
inputs_description=_KWARGS_DESCRIPTION,
# This defines the format of each prediction and reference
features=datasets.Features({
"predictions": datasets.Sequence(
datasets.Sequence(datasets.Value("float"))
),
"references": datasets.Features({ "all":
datasets.Sequence(datasets.Sequence(datasets.Value("float")))}
)
}), #couldn't get this to work
# Additional links to the codebase or references
codebase_urls=["http://github.com/path/to/codebase/of/new_module"],
reference_urls=["http://path.to.reference.url/new_module"],
)
def _download_and_prepare(self, dl_manager):
"""Optional: download external resources useful to compute the scores"""
# TODO: Download external resources if needed
pass
def _compute(
self,
predictions,
references,
):
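        """Compute metrics for each filter range configured in `self.filter_dict`.

        Per-sequence raw counts from `self.metric.calculate` are accumulated with
        `sum_dicts`, then converted into final scores via `self.metric.derive_scores`.

        Returns:
            dict: mapping from filter-range name (e.g. "all") to derived scores.
        """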
results = {}
filter_ranges = self.filter_dict["ranges"]
for filter_range in filter_ranges:
filter_range_name = filter_range[0]
range_results = {}
for sequence_predictions, sequence_references in zip(predictions, references):
sequence_range_results = self.metric.calculate(
sequence_predictions,
sequence_references[filter_range_name],
)
range_results = sum_dicts(range_results, sequence_range_results)
results[filter_range_name] = self.metric.derive_scores(range_results, self.recognition_thresholds)
return results
def compute_from_payload(self,
payload: Payload,
):
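        """Compute overall and per-sequence results for every model in `payload`.

        Returns:
            dict: results[model_name] = {
                "overall": metrics over all sequences combined,
                "per_sequence": {sequence_name: metrics for that sequence},
            }
        """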
results = {}
for model_name in payload.models:
results[model_name] = {"overall": {}, "per_sequence": {}}
# per-sequence loop
progress_bar = tqdm(payload.sequences.items())
for seq_name, sequence in progress_bar:
progress_bar.set_description(f"Getting sequence payload: {seq_name}")
# create new payload only with specific sequence and model
sequence_payload = Payload(
dataset=payload.dataset,
gt_field_name=payload.gt_field_name,
models=[model_name],
sequences={seq_name: sequence}
)
progress_bar.set_description(f"Processing sequence: {seq_name}")
predictions, references = payload_to_uf_metrics(sequence_payload, model_name=model_name, filter_dict=self.filter_dict)
results[model_name]["per_sequence"][seq_name] = self._compute(predictions=predictions, references=references)
# overall
model_payload = Payload(
dataset=payload.dataset,
gt_field_name=payload.gt_field_name,
models=[model_name],
sequences=payload.sequences
)
predictions, references = payload_to_uf_metrics(model_payload, model_name=model_name, filter_dict=self.filter_dict)
results[model_name]["overall"] = self._compute(predictions=predictions, references=references)
return results
def wandb(
self,
results,
wandb_section: str = None,
wandb_runs = None,
wandb_project="user_friendly_metrics",
log_plots: bool = True,
debug: bool = False,
log_per_sequence = False
):
"""
Logs metrics to Weights and Biases (wandb) for tracking and visualization, including categorized bar charts for overall metrics.
Args:
results (dict): Results dictionary with 'overall' and 'per_sequence' keys.
wandb_section (str, optional): W&B section for metric grouping. Defaults to None.
wandb_project (str, optional): The name of the wandb project. Defaults to 'user_friendly_metrics'.
log_plots (bool, optional): Generates categorized bar charts for overall metrics. Defaults to True.
debug (bool, optional): Logs detailed summaries and histories to the terminal console. Defaults to False.
"""
current_datetime = datetime.datetime.now()
formatted_datetime = current_datetime.strftime("%Y-%m-%d_%H-%M-%S")
wandb.login(key=os.getenv("WANDB_API_KEY"))
if wandb_runs is not None:
assert len(wandb_runs) == len(results), "runs and results must have the same length"
else:
wandb_runs = [f"{i}_{formatted_datetime}" for i in list(results.keys())]
for wandb_run_name, result in zip(wandb_runs, results.values()):
self.wandb_run(result = result,
wandb_run_name = wandb_run_name,
wandb_project = wandb_project,
debug = debug,
wandb_section = wandb_section,
log_plots = log_plots,
log_per_sequence = log_per_sequence)
def wandb_run(self, result, wandb_run_name, wandb_project, debug, wandb_section = None, log_plots = True, log_per_sequence = False):
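        """Log one model's results to a single W&B run.

        Overall metrics are always logged; categorized bar charts are added when
        `log_plots` is True, and per-sequence metrics (sorted by recall, descending)
        are logged only when `log_per_sequence` is True.
        """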
run = wandb.init(
project = wandb_project,
name = wandb_run_name,
reinit = True,
settings = wandb.Settings(silent=not debug),
)
categories = {
"user_friendly_metrics": {
f"mostly_tracked_score_{str(threshold).replace('.', '_')}" for threshold in self.recognition_thresholds
},
"evaluation_metrics_dev": {
"recall",
},
"user_friendly_metrics_dev": {
f"mostly_tracked_count_{str(threshold).replace('.', '_')}" for threshold in self.recognition_thresholds
}.union("unique_object_count"),
"predictions_summary": {
"tp",
"fn",
},
}
chart_data = {key: [] for key in categories.keys()}
# Log overall metrics
if "overall" in result:
for metric, value in result["overall"]["all"].items():
log_key = (
f"{wandb_section}/overall/{metric}"
if wandb_section
else f"overall/{metric}"
)
run.log({log_key: value})
if debug:
print(f" {log_key} = {value}")
for category, metrics in categories.items():
if metric in metrics:
chart_data[category].append([metric, value])
print("----------------------------------------------------")
if log_plots:
for category, data in chart_data.items():
if data:
table_data = [[label, value] for label, value in data]
table = wandb.Table(data=table_data, columns=["metrics", "value"])
run.log(
{
f"{category}_bar_chart": wandb.plot.bar(
table,
"metrics",
"value",
title=f"{category.replace('_', ' ').title()}",
)
}
)
if log_per_sequence:
if "per_sequence" in result:
sorted_sequences = sorted(
result["per_sequence"].items(),
                    key=lambda x: x[1].get("all", {}).get("recall", 0),
reverse=True, # Set to True for descending order
)
for sequence_name, sequence_data in sorted_sequences:
for metric, value in sequence_data["all"].items():
log_key = (
f"{wandb_section}/per_sequence/{sequence_name}/{metric}"
if wandb_section
else f"per_sequence/{sequence_name}/{metric}"
)
run.log({log_key: value})
if debug:
print(f" {log_key} = {value}")
print("----------------------------------------------------")
if debug:
print("\nDebug Mode: Logging Summary and History")
print(f"Results Summary:\n{result}")
print(f"WandB Settings:\n{run.settings}")
print("All metrics have been logged.")
run.finish()
def sum_dicts(*dicts):
"""
Sums multiple dictionaries with depth one. If keys overlap, their values are summed.
If keys are unique, they are simply included in the result.
Args:
*dicts: Any number of dictionaries to be summed.
Returns:
A single dictionary with the summed values.
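    Example:
        >>> sum_dicts({"tp": 1, "fn": 2}, {"tp": 3, "gt": 4})
        {'tp': 4, 'fn': 2, 'gt': 4}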
"""
result = {}
for d in dicts:
for key, value in d.items():
if key in result:
result[key] += value
else:
result[key] = value
return result
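

# ---------------------------------------------------------------------------
# Illustrative usage sketch (assumptions): `my_payload` stands in for a
# seametrics Payload built elsewhere, and the metric is constructed directly
# for simplicity; in a Hugging Face Space it would normally be instantiated
# via `evaluate.load`. Logging to W&B requires the WANDB_API_KEY env variable.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    metric = UserFriendlyMetrics()  # default IoU and recognition thresholds

    # my_payload = Payload(...)  # hypothetical: build or load a Payload here
    # results = metric.compute_from_payload(my_payload)
    # metric.wandb(results, wandb_project="user_friendly_metrics", log_plots=True)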