Spaces:
Build error
Build error
File size: 5,273 Bytes
28c256d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Iterator, List, Optional, Sequence, Union
from mmengine.dataset import pseudo_collate
from mmengine.registry import EVALUATOR, METRICS
from mmengine.structures import BaseDataElement
from .metric import BaseMetric
@EVALUATOR.register_module()
class Evaluator:
    """Compose and drive a collection of :class:`BaseMetric` instances.

    Args:
        metrics (dict or BaseMetric or Sequence): The config of metrics.
    """

    def __init__(self, metrics: Union[dict, BaseMetric, Sequence]):
        self._dataset_meta: Optional[dict] = None
        if not isinstance(metrics, Sequence):
            metrics = [metrics]
        # Configs (plain dicts) are instantiated through the registry;
        # ready-made metric objects are taken as-is.
        self.metrics: List[BaseMetric] = [
            METRICS.build(metric) if isinstance(metric, dict) else metric
            for metric in metrics
        ]

    @property
    def dataset_meta(self) -> Optional[dict]:
        """Optional[dict]: Meta info of the dataset."""
        return self._dataset_meta

    @dataset_meta.setter
    def dataset_meta(self, dataset_meta: dict) -> None:
        """Propagate the dataset meta info to this evaluator and every
        wrapped metric."""
        self._dataset_meta = dataset_meta
        for metric in self.metrics:
            metric.dataset_meta = dataset_meta

    def process(self,
                data_samples: Sequence[BaseDataElement],
                data_batch: Optional[Any] = None):
        """Hand one batch of predictions to every metric.

        ``BaseDataElement`` samples are converted to plain dicts first;
        anything else is forwarded unchanged.

        Args:
            data_samples (Sequence[BaseDataElement]): predictions of the
                model, and the ground truth of the validation set.
            data_batch (Any, optional): A batch of data from the dataloader.
        """
        converted = [
            sample.to_dict() if isinstance(sample, BaseDataElement)
            else sample
            for sample in data_samples
        ]
        for metric in self.metrics:
            metric.process(data_batch, converted)

    def evaluate(self, size: int) -> dict:
        """Run ``evaluate`` on every metric and merge the result dicts.

        Args:
            size (int): Length of the entire validation dataset. When batch
                size > 1, the dataloader may pad some data samples to make
                sure all ranks have the same length of dataset slice. The
                ``collect_results`` function will drop the padded data based
                on this size.

        Returns:
            dict: Evaluation results of all metrics. The keys are the names
                of the metrics, and the values are corresponding results.

        Raises:
            ValueError: If two metrics report a result under the same name.
        """
        merged: dict = {}
        for metric in self.metrics:
            results = metric.evaluate(size)
            # A duplicated key would silently overwrite one metric's result
            # with another's, so refuse it outright.
            for name in results:
                if name in merged:
                    raise ValueError(
                        'There are multiple evaluation results with the same '
                        f'metric name {name}. Please make sure all metrics '
                        'have different prefixes.')
            merged.update(results)
        return merged

    def offline_evaluate(self,
                         data_samples: Sequence,
                         data: Optional[Sequence] = None,
                         chunk_size: int = 1):
        """Offline evaluate the dumped predictions on the given data .

        Args:
            data_samples (Sequence): All predictions and ground truth of the
                model and the validation set.
            data (Sequence, optional): All data of the validation set.
            chunk_size (int): The number of data samples and predictions to be
                processed in a batch.
        """

        # Local helper: lazily slice an iterator into lists of at most
        # ``size`` items, yielding a shorter final batch if needed.
        def _batched(iterator: Iterator, size=1):
            exhausted = False
            while not exhausted:
                batch = []
                for _ in range(size):
                    try:
                        batch.append(next(iterator))
                    except StopIteration:
                        exhausted = True
                        break
                if batch:
                    yield batch

        if data is not None:
            assert len(data_samples) == len(data), (
                'data_samples and data should have the same length, but got '
                f'data_samples length: {len(data_samples)} '
                f'data length: {len(data)}')
            data = _batched(iter(data), chunk_size)

        total = 0
        for sample_chunk in _batched(iter(data_samples), chunk_size):
            if data is not None:
                collated = pseudo_collate(next(data))  # type: ignore
            else:
                collated = None
            total += len(sample_chunk)
            self.process(sample_chunk, collated)
        return self.evaluate(total)
|