Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import os.path as osp | |
| import shutil | |
| from collections import OrderedDict | |
| from typing import Dict, Optional, Sequence | |
| try: | |
| import cityscapesscripts.evaluation.evalPixelLevelSemanticLabeling as CSEval # noqa | |
| import cityscapesscripts.helpers.labels as CSLabels | |
| except ImportError: | |
| CSLabels = None | |
| CSEval = None | |
| import numpy as np | |
| from mmengine.dist import is_main_process, master_only | |
| from mmengine.evaluator import BaseMetric | |
| from mmengine.logging import MMLogger, print_log | |
| from mmengine.utils import mkdir_or_exist | |
| from PIL import Image | |
| from mmseg.registry import METRICS | |
| class CityscapesMetric(BaseMetric): | |
| """Cityscapes evaluation metric. | |
| Args: | |
| output_dir (str): The directory for output prediction | |
| ignore_index (int): Index that will be ignored in evaluation. | |
| Default: 255. | |
| format_only (bool): Only format result for results commit without | |
| perform evaluation. It is useful when you want to format the result | |
| to a specific format and submit it to the test server. | |
| Defaults to False. | |
| keep_results (bool): Whether to keep the results. When ``format_only`` | |
| is True, ``keep_results`` must be True. Defaults to False. | |
| collect_device (str): Device name used for collecting results from | |
| different ranks during distributed training. Must be 'cpu' or | |
| 'gpu'. Defaults to 'cpu'. | |
| prefix (str, optional): The prefix that will be added in the metric | |
| names to disambiguate homonymous metrics of different evaluators. | |
| If prefix is not provided in the argument, self.default_prefix | |
| will be used instead. Defaults to None. | |
| """ | |
| def __init__(self, | |
| output_dir: str, | |
| ignore_index: int = 255, | |
| format_only: bool = False, | |
| keep_results: bool = False, | |
| collect_device: str = 'cpu', | |
| prefix: Optional[str] = None, | |
| **kwargs) -> None: | |
| super().__init__(collect_device=collect_device, prefix=prefix) | |
| if CSEval is None: | |
| raise ImportError('Please run "pip install cityscapesscripts" to ' | |
| 'install cityscapesscripts first.') | |
| self.output_dir = output_dir | |
| self.ignore_index = ignore_index | |
| self.format_only = format_only | |
| if format_only: | |
| assert keep_results, ( | |
| 'When format_only is True, the results must be keep, please ' | |
| f'set keep_results as True, but got {keep_results}') | |
| self.keep_results = keep_results | |
| self.prefix = prefix | |
| if is_main_process(): | |
| mkdir_or_exist(self.output_dir) | |
| def __del__(self) -> None: | |
| """Clean up.""" | |
| if not self.keep_results: | |
| shutil.rmtree(self.output_dir) | |
| def process(self, data_batch: dict, data_samples: Sequence[dict]) -> None: | |
| """Process one batch of data and data_samples. | |
| The processed results should be stored in ``self.results``, which will | |
| be used to computed the metrics when all batches have been processed. | |
| Args: | |
| data_batch (dict): A batch of data from the dataloader. | |
| data_samples (Sequence[dict]): A batch of outputs from the model. | |
| """ | |
| mkdir_or_exist(self.output_dir) | |
| for data_sample in data_samples: | |
| pred_label = data_sample['pred_sem_seg']['data'][0].cpu().numpy() | |
| # when evaluating with official cityscapesscripts, | |
| # labelIds should be used | |
| pred_label = self._convert_to_label_id(pred_label) | |
| basename = osp.splitext(osp.basename(data_sample['img_path']))[0] | |
| png_filename = osp.abspath( | |
| osp.join(self.output_dir, f'{basename}.png')) | |
| output = Image.fromarray(pred_label.astype(np.uint8)).convert('P') | |
| output.save(png_filename) | |
| if self.format_only: | |
| # format_only always for test dataset without ground truth | |
| gt_filename = '' | |
| else: | |
| # when evaluating with official cityscapesscripts, | |
| # **_gtFine_labelIds.png is used | |
| gt_filename = data_sample['seg_map_path'].replace( | |
| 'labelTrainIds.png', 'labelIds.png') | |
| self.results.append((png_filename, gt_filename)) | |
| def compute_metrics(self, results: list) -> Dict[str, float]: | |
| """Compute the metrics from processed results. | |
| Args: | |
| results (list): Testing results of the dataset. | |
| Returns: | |
| dict[str: float]: Cityscapes evaluation results. | |
| """ | |
| logger: MMLogger = MMLogger.get_current_instance() | |
| if self.format_only: | |
| logger.info(f'results are saved to {osp.dirname(self.output_dir)}') | |
| return OrderedDict() | |
| msg = 'Evaluating in Cityscapes style' | |
| if logger is None: | |
| msg = '\n' + msg | |
| print_log(msg, logger=logger) | |
| eval_results = dict() | |
| print_log( | |
| f'Evaluating results under {self.output_dir} ...', logger=logger) | |
| CSEval.args.evalInstLevelScore = True | |
| CSEval.args.predictionPath = osp.abspath(self.output_dir) | |
| CSEval.args.evalPixelAccuracy = True | |
| CSEval.args.JSONOutput = False | |
| pred_list, gt_list = zip(*results) | |
| metric = dict() | |
| eval_results.update( | |
| CSEval.evaluateImgLists(pred_list, gt_list, CSEval.args)) | |
| metric['averageScoreCategories'] = eval_results[ | |
| 'averageScoreCategories'] | |
| metric['averageScoreInstCategories'] = eval_results[ | |
| 'averageScoreInstCategories'] | |
| return metric | |
| def _convert_to_label_id(result): | |
| """Convert trainId to id for cityscapes.""" | |
| if isinstance(result, str): | |
| result = np.load(result) | |
| result_copy = result.copy() | |
| for trainId, label in CSLabels.trainId2label.items(): | |
| result_copy[result == trainId] = label.id | |
| return result_copy | |