# Spaces: Build error (appears to be web-page residue captured with this file; not part of the module)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from itertools import islice
from typing import Any, Iterator, List, Optional, Sequence, Union

from mmengine.dataset import pseudo_collate
from mmengine.registry import EVALUATOR, METRICS
from mmengine.structures import BaseDataElement

from .metric import BaseMetric
# NOTE(review): ``EVALUATOR`` is imported by this file but never used in the
# visible code -- a registration decorator on this class may have been lost.
# TODO confirm against the upstream module before relying on registry lookup.
class Evaluator:
    """Wrapper class to compose multiple :class:`BaseMetric` instances.

    Args:
        metrics (dict or BaseMetric or Sequence): The config of metrics.
            A single metric (or a single config dict) is wrapped into a
            one-element list. Every ``dict`` entry is built through
            ``METRICS.build``; any other entry is stored as given.
    """

    def __init__(self, metrics: Union[dict, 'BaseMetric', Sequence]):
        self._dataset_meta: Optional[dict] = None
        # Normalize a lone metric / config dict to a sequence of one.
        if not isinstance(metrics, Sequence):
            metrics = [metrics]
        self.metrics: List[BaseMetric] = []
        for metric in metrics:
            if isinstance(metric, dict):
                self.metrics.append(METRICS.build(metric))
            else:
                self.metrics.append(metric)

    @property
    def dataset_meta(self) -> Optional[dict]:
        """Optional[dict]: Meta info of the dataset."""
        return self._dataset_meta

    @dataset_meta.setter
    def dataset_meta(self, dataset_meta: dict) -> None:
        """Set the dataset meta info to the evaluator and its metrics."""
        self._dataset_meta = dataset_meta
        # Keep every wrapped metric in sync with the evaluator.
        for metric in self.metrics:
            metric.dataset_meta = dataset_meta

    def process(self,
                data_samples: Sequence['BaseDataElement'],
                data_batch: Optional[Any] = None):
        """Convert ``BaseDataElement`` samples to dicts and invoke the
        ``process`` method of each metric.

        Args:
            data_samples (Sequence[BaseDataElement]): predictions of the
                model, and the ground truth of the validation set. Items
                that are not ``BaseDataElement`` instances are forwarded
                unchanged.
            data_batch (Any, optional): A batch of data from the dataloader.
        """
        _data_samples = [
            sample.to_dict() if isinstance(sample, BaseDataElement) else sample
            for sample in data_samples
        ]

        # Every metric receives the same converted list.
        for metric in self.metrics:
            metric.process(data_batch, _data_samples)

    def evaluate(self, size: int) -> dict:
        """Invoke ``evaluate`` method of each metric and collect the metrics
        dictionary.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            dict: Evaluation results of all metrics. The keys are the names
            of the metrics, and the values are corresponding results.

        Raises:
            ValueError: If two metrics report a result under the same name.
        """
        metrics = {}
        for metric in self.metrics:
            _results = metric.evaluate(size)

            # Check metric name conflicts before merging, so an earlier
            # metric's value is never silently overwritten.
            for name in _results:
                if name in metrics:
                    raise ValueError(
                        'There are multiple evaluation results with the same '
                        f'metric name {name}. Please make sure all metrics '
                        'have different prefixes.')

            metrics.update(_results)
        return metrics

    def offline_evaluate(self,
                         data_samples: Sequence,
                         data: Optional[Sequence] = None,
                         chunk_size: int = 1):
        """Offline evaluate the dumped predictions on the given data.

        Args:
            data_samples (Sequence): All predictions and ground truth of the
                model and the validation set.
            data (Sequence, optional): All data of the validation set. When
                given, it must have the same length as ``data_samples``.
            chunk_size (int): The number of data samples and predictions to
                be processed in a batch. Must be at least 1.

        Returns:
            dict: The result of :meth:`evaluate` called with the number of
            processed data samples.

        Raises:
            ValueError: If ``chunk_size`` is smaller than 1.
            AssertionError: If ``data`` is given and its length differs from
                that of ``data_samples``.
        """
        # A non-positive chunk size can never consume the input, which would
        # make the chunking loop spin forever; reject it up front.
        if chunk_size < 1:
            raise ValueError(
                f'chunk_size must be a positive integer, but got {chunk_size}')

        # support chunking iterable objects
        def get_chunks(seq: Iterator, chunk_size=1):
            while True:
                chunk = list(islice(seq, chunk_size))
                if not chunk:
                    return
                yield chunk

        if data is not None:
            assert len(data_samples) == len(data), (
                'data_samples and data should have the same length, but got '
                f'data_samples length: {len(data_samples)} '
                f'data length: {len(data)}')
            data = get_chunks(iter(data), chunk_size)

        size = 0
        for output_chunk in get_chunks(iter(data_samples), chunk_size):
            if data is not None:
                # Equal lengths and equal chunk size keep both chunk
                # streams aligned, so ``next`` cannot run dry here.
                data_chunk = pseudo_collate(next(data))  # type: ignore
            else:
                data_chunk = None

            size += len(output_chunk)
            self.process(output_chunk, data_chunk)

        return self.evaluate(size)