Spaces:
Runtime error
Runtime error
| import contextlib | |
| import datetime | |
| import os | |
| import logging | |
| import numpy as np | |
| from PIL import Image | |
| from detectron2.utils.file_io import PathManager | |
| logger = logging.getLogger(__name__) | |
def load_sem_seg(gt_root, image_root, gt_ext="png", image_ext="jpg", meta=None):
    """
    Load semantic segmentation datasets. All files under "gt_root" with "gt_ext" extension are
    treated as ground truth annotations and all files under "image_root" with "image_ext" extension
    as input images. Ground truth and input images are matched using file paths relative to
    "gt_root" and "image_root" respectively without taking into account file extensions.
    This works for COCO as well as some other datasets.

    Args:
        gt_root (str): full path to ground truth semantic segmentation files. Semantic segmentation
            annotations are stored as images with integer values in pixels that represent
            corresponding semantic labels.
        image_root (str): the directory where the input images are.
        gt_ext (str): file extension for ground truth annotations.
        image_ext (str): file extension for input images.
        meta: arbitrary object stored as-is under the "meta" key of every returned record.
            The same object is shared by all records (it is not copied).

    Returns:
        list[dict]:
            a list of dicts in detectron2 standard format without instance-level
            annotation. Each dict has the keys "file_name", "sem_seg_file_name" and "meta".

    Raises:
        AssertionError: if no file under "gt_root" ends with "gt_ext".

    Notes:
        1. This function does not read the image and ground truth files.
           The results do not have the "image" and "sem_seg" fields.
        2. File names are only compared when the two directories contain a different number
           of matching files. When the counts are equal, the two sorted lists are paired by
           position.
    """

    # We match input images with ground truth based on their relative filepaths (without file
    # extensions) starting from 'image_root' and 'gt_root' respectively.
    def file2id(folder_path, file_path):
        # extract relative path starting from `folder_path`
        image_id = os.path.normpath(os.path.relpath(file_path, start=folder_path))
        # remove file extension
        return os.path.splitext(image_id)[0]

    def strip_ext(file_path, ext):
        # Drop exactly `ext` from the end of the basename. The empty extension is guarded
        # because `name[:-0]` evaluates to "" and would collapse every file to the same key.
        name = os.path.basename(file_path)
        return name[: -len(ext)] if ext else name

    input_files = sorted(
        (os.path.join(image_root, f) for f in PathManager.ls(image_root) if f.endswith(image_ext)),
        key=lambda file_path: file2id(image_root, file_path),
    )
    gt_files = sorted(
        (os.path.join(gt_root, f) for f in PathManager.ls(gt_root) if f.endswith(gt_ext)),
        key=lambda file_path: file2id(gt_root, file_path),
    )

    assert len(gt_files) > 0, "No annotations found in {}.".format(gt_root)

    # Use the intersection, so that val2017_100 annotations can run smoothly with val2017 images
    if len(input_files) != len(gt_files):
        # `Logger.warn` is a deprecated alias; `warning` is the supported spelling.
        logger.warning(
            "Directory %s and %s has %s and %s files, respectively.",
            image_root,
            gt_root,
            len(input_files),
            len(gt_files),
        )
        # Map each extension-less basename back to its original path so the selected
        # entries are exactly the paths that were listed above.
        input_by_key = {strip_ext(f, image_ext): f for f in input_files}
        gt_by_key = {strip_ext(f, gt_ext): f for f in gt_files}
        # sort, otherwise each worker may obtain a list[dict] in different order
        intersect = sorted(input_by_key.keys() & gt_by_key.keys())
        logger.warning("Will use their intersection of %s files.", len(intersect))
        input_files = [input_by_key[key] for key in intersect]
        gt_files = [gt_by_key[key] for key in intersect]

    logger.info(
        "Loaded %s images with semantic segmentation from %s", len(input_files), image_root
    )

    # One record per (image, ground truth) pair; files themselves are not opened here.
    return [
        {"file_name": img_path, "sem_seg_file_name": gt_path, "meta": meta}
        for img_path, gt_path in zip(input_files, gt_files)
    ]