Spaces:
Running
on
Zero
Running
on
Zero
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
import bisect | |
import copy | |
import itertools | |
import logging | |
import numpy as np | |
import operator | |
import pickle | |
import torch.utils.data | |
from fvcore.common.file_io import PathManager | |
from tabulate import tabulate | |
from termcolor import colored | |
from detectron2.structures import BoxMode | |
from detectron2.utils.comm import get_world_size | |
from detectron2.utils.env import seed_all_rng | |
from detectron2.utils.logger import log_first_n | |
from . import samplers | |
from .catalog import DatasetCatalog, MetadataCatalog | |
from .common import AspectRatioGroupedDataset, DatasetFromList, MapDataset | |
from .dataset_mapper import DatasetMapper | |
from .detection_utils import check_metadata_consistency | |
""" | |
This file contains the default logic to build a dataloader for training or testing. | |
""" | |
__all__ = [ | |
"build_detection_train_loader", | |
"build_detection_test_loader", | |
"get_detection_dataset_dicts", | |
"load_proposals_into_dataset", | |
"print_instances_class_histogram", | |
] | |
def filter_images_with_only_crowd_annotations(dataset_dicts):
    """
    Drop every image that has no annotations at all, or whose annotations
    are all marked as crowd (i.e. keep only images with at least one
    non-crowd annotation).
    A common training-time preprocessing on COCO dataset.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.

    Returns:
        list[dict]: the same format, but filtered.
    """
    original_count = len(dataset_dicts)

    def has_non_crowd(annotations):
        # An annotation without an "iscrowd" key counts as non-crowd.
        return any(ann.get("iscrowd", 0) == 0 for ann in annotations)

    kept = [record for record in dataset_dicts if has_non_crowd(record["annotations"])]
    kept_count = len(kept)

    logging.getLogger(__name__).info(
        "Removed {} images with no usable annotations. {} images left.".format(
            original_count - kept_count, kept_count
        )
    )
    return kept
def filter_images_with_few_keypoints(dataset_dicts, min_keypoints_per_image):
    """
    Keep only the images that contain enough visible keypoints.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
        min_keypoints_per_image (int): an image is kept when the number of
            visible keypoints summed over its annotations is at least this value.

    Returns:
        list[dict]: the same format as dataset_dicts, but filtered.
    """
    original_count = len(dataset_dicts)

    def count_visible_keypoints(record):
        # Keypoints are stored flat as [x1, y1, v1, x2, y2, v2, ...]; every
        # third value is the visibility flag and v > 0 is counted as visible.
        total = 0
        for ann in record["annotations"]:
            if "keypoints" not in ann:
                continue
            visibility = np.array(ann["keypoints"][2::3])
            total += (visibility > 0).sum()
        return total

    kept = []
    for record in dataset_dicts:
        if count_visible_keypoints(record) >= min_keypoints_per_image:
            kept.append(record)

    logging.getLogger(__name__).info(
        "Removed {} images with fewer than {} keypoints.".format(
            original_count - len(kept), min_keypoints_per_image
        )
    )
    return kept
def load_proposals_into_dataset(dataset_dicts, proposal_file):
    """
    Load precomputed object proposals into the dataset.

    The proposal file should be a pickled dict with the following keys:

    - "ids": list[int] or list[str], the image ids
    - "boxes": list[np.ndarray], each is an Nx4 array of boxes corresponding to the image id
    - "objectness_logits": list[np.ndarray], each is an N sized array of objectness scores
      corresponding to the boxes.
    - "bbox_mode": the BoxMode of the boxes array. Defaults to ``BoxMode.XYXY_ABS``.

    The legacy key names "indexes" and "scores" are accepted in place of
    "ids" and "objectness_logits" respectively.

    Args:
        dataset_dicts (list[dict]): annotations in Detectron2 Dataset format.
            The records are modified in place.
        proposal_file (str): file path of pre-computed proposals, in pkl format.

    Returns:
        list[dict]: the same format as dataset_dicts, but added proposal field.

    Raises:
        KeyError: if an image in `dataset_dicts` has no entry in the proposal file.
    """
    logger = logging.getLogger(__name__)
    logger.info("Loading proposals from: {}".format(proposal_file))

    # SECURITY NOTE: unpickling can execute arbitrary code. Only load proposal
    # files that come from a trusted source.
    with PathManager.open(proposal_file, "rb") as f:
        proposals = pickle.load(f, encoding="latin1")

    # Rename the key names in D1 proposal files
    rename_keys = {"indexes": "ids", "scores": "objectness_logits"}
    for old_key, new_key in rename_keys.items():
        if old_key in proposals:
            proposals[new_key] = proposals.pop(old_key)

    # Fetch the indexes of all proposals that are in the dataset.
    # Image ids are compared as str since they could be int on either side.
    img_ids = {str(record["image_id"]) for record in dataset_dicts}
    id_to_index = {
        str(proposal_id): index
        for index, proposal_id in enumerate(proposals["ids"])
        if str(proposal_id) in img_ids
    }

    # Assuming default bbox_mode of precomputed proposals are 'XYXY_ABS'
    bbox_mode = BoxMode(proposals["bbox_mode"]) if "bbox_mode" in proposals else BoxMode.XYXY_ABS

    for record in dataset_dicts:
        image_id = str(record["image_id"])
        try:
            i = id_to_index[image_id]
        except KeyError:
            # Same exception type as a plain dict lookup, but say which image
            # and which file are involved.
            raise KeyError(
                "No proposals found for image_id '{}' in proposal file '{}'".format(
                    image_id, proposal_file
                )
            ) from None

        boxes = proposals["boxes"][i]
        objectness_logits = proposals["objectness_logits"][i]
        # Sort the proposals in descending order of the scores
        inds = objectness_logits.argsort()[::-1]
        record["proposal_boxes"] = boxes[inds]
        record["proposal_objectness_logits"] = objectness_logits[inds]
        record["proposal_bbox_mode"] = bbox_mode

    return dataset_dicts
def _quantize(x, bin_edges): | |
bin_edges = copy.copy(bin_edges) | |
bin_edges = sorted(bin_edges) | |
quantized = list(map(lambda y: bisect.bisect_right(bin_edges, y), x)) | |
return quantized | |
def print_instances_class_histogram(dataset_dicts, class_names):
    """
    Log a table with the number of non-crowd instances per category.

    Args:
        dataset_dicts (list[dict]): list of dataset dicts. Each dict must have
            an "annotations" list whose items carry a "category_id".
        class_names (list[str]): list of class names (zero-indexed).
    """
    num_classes = len(class_names)
    hist_bins = np.arange(num_classes + 1)
    # The `np.int` alias was removed from NumPy (1.24); the builtin `int` is
    # the dtype it used to stand for.
    histogram = np.zeros((num_classes,), dtype=int)
    for entry in dataset_dicts:
        annos = entry["annotations"]
        classes = [x["category_id"] for x in annos if not x.get("iscrowd", 0)]
        # NOTE: np.histogram ignores values outside the bin range, and its last
        # bin is closed on the right, so a category_id equal to num_classes
        # would be counted under the last class.
        histogram += np.histogram(classes, bins=hist_bins)[0]

    N_COLS = min(6, len(class_names) * 2)

    def short_name(x):
        # make long class names shorter. useful for lvis
        if len(x) > 13:
            return x[:11] + ".."
        return x

    # Flat list of alternating [name, count, name, count, ...].
    data = list(
        itertools.chain.from_iterable(
            (short_name(class_names[i]), int(v)) for i, v in enumerate(histogram)
        )
    )
    total_num_instances = sum(data[1::2])
    # Pad so that the "total" entry starts on a fresh table row.
    data.extend([None] * (N_COLS - (len(data) % N_COLS)))
    if num_classes > 1:
        data.extend(["total", total_num_instances])
    # Reshape the flat list into rows of N_COLS cells.
    data = itertools.zip_longest(*[data[i::N_COLS] for i in range(N_COLS)])
    table = tabulate(
        data,
        headers=["category", "#instances"] * (N_COLS // 2),
        tablefmt="pipe",
        numalign="left",
        stralign="center",
    )
    log_first_n(
        logging.INFO,
        "Distribution of instances among all {} categories:\n".format(num_classes)
        + colored(table, "cyan"),
        key="message",
    )
def get_detection_dataset_dicts(
    dataset_names, filter_empty=True, min_keypoints=0, proposal_files=None
):
    """
    Load and prepare dataset dicts for instance detection/segmentation and semantic segmentation.

    Args:
        dataset_names (str or list[str]): a dataset name or a list of dataset names
        filter_empty (bool): whether to filter out images without instance annotations
        min_keypoints (int): filter out images with fewer keypoints than
            `min_keypoints`. Set to 0 to do nothing.
        proposal_files (str or list[str]): if given, an object proposal file or a
            list of object proposal files that match each dataset in `dataset_names`.

    Returns:
        list[dict]: the dataset dicts of all requested datasets, concatenated.
    """
    # A bare string would otherwise be iterated character by character.
    if isinstance(dataset_names, str):
        dataset_names = [dataset_names]
    if isinstance(proposal_files, str):
        proposal_files = [proposal_files]

    assert len(dataset_names), "At least one dataset name is required"
    dataset_dicts = [DatasetCatalog.get(dataset_name) for dataset_name in dataset_names]
    for dataset_name, dicts in zip(dataset_names, dataset_dicts):
        assert len(dicts), "Dataset '{}' is empty!".format(dataset_name)

    if proposal_files is not None:
        assert len(dataset_names) == len(
            proposal_files
        ), "Got {} dataset names but {} proposal files".format(
            len(dataset_names), len(proposal_files)
        )
        # load precomputed proposals from proposal files
        dataset_dicts = [
            load_proposals_into_dataset(dataset_i_dicts, proposal_file)
            for dataset_i_dicts, proposal_file in zip(dataset_dicts, proposal_files)
        ]

    dataset_dicts = list(itertools.chain.from_iterable(dataset_dicts))

    # Only the first record is inspected to decide what kind of data this is.
    has_instances = "annotations" in dataset_dicts[0]
    # Keep images without instance-level GT if the dataset has semantic labels.
    if filter_empty and has_instances and "sem_seg_file_name" not in dataset_dicts[0]:
        dataset_dicts = filter_images_with_only_crowd_annotations(dataset_dicts)

    if min_keypoints > 0 and has_instances:
        dataset_dicts = filter_images_with_few_keypoints(dataset_dicts, min_keypoints)

    if has_instances:
        # The histogram is informational only: any AttributeError raised in
        # this block skips it instead of failing the data loading.
        try:
            class_names = MetadataCatalog.get(dataset_names[0]).thing_classes
            check_metadata_consistency("thing_classes", dataset_names)
            print_instances_class_histogram(dataset_dicts, class_names)
        except AttributeError:  # class names are not available for this dataset
            pass
    return dataset_dicts
def build_detection_train_loader(cfg, mapper=None):
    """
    Build the training data loader described by `cfg`.

    The loader is assembled in these steps:

    1. Use the dataset names in config to query :class:`DatasetCatalog`, and obtain a list of dicts.
    2. Coordinate a random shuffle order shared among all processes (all GPUs)
    3. Each process spawn another few workers to process the dicts. Each worker will:

       * Map each metadata dict into another format to be consumed by the model.
       * Batch them by simply putting dicts into a list.

    The batched ``list[mapped_dict]`` is what this dataloader will yield.

    Args:
        cfg (CfgNode): the config
        mapper (callable): a callable which takes a sample (dict) from dataset and
            returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, True)`.

    Returns:
        an infinite iterator of training data
    """
    # "Workers" in the messages below are the distributed processes
    # (get_world_size()), not the DataLoader worker processes.
    world_size = get_world_size()
    total_batch_size = cfg.SOLVER.IMS_PER_BATCH
    assert (
        total_batch_size % world_size == 0
    ), "SOLVER.IMS_PER_BATCH ({}) must be divisible by the number of workers ({}).".format(
        total_batch_size, world_size
    )
    assert (
        total_batch_size >= world_size
    ), "SOLVER.IMS_PER_BATCH ({}) must be larger than the number of workers ({}).".format(
        total_batch_size, world_size
    )
    per_process_batch_size = total_batch_size // world_size

    # Keypoint filtering and proposal loading are only requested when the
    # corresponding model options are switched on.
    if cfg.MODEL.KEYPOINT_ON:
        min_keypoints = cfg.MODEL.ROI_KEYPOINT_HEAD.MIN_KEYPOINTS_PER_IMAGE
    else:
        min_keypoints = 0
    if cfg.MODEL.LOAD_PROPOSALS:
        proposal_files = cfg.DATASETS.PROPOSAL_FILES_TRAIN
    else:
        proposal_files = None

    dataset_dicts = get_detection_dataset_dicts(
        cfg.DATASETS.TRAIN,
        filter_empty=cfg.DATALOADER.FILTER_EMPTY_ANNOTATIONS,
        min_keypoints=min_keypoints,
        proposal_files=proposal_files,
    )
    dataset = MapDataset(
        DatasetFromList(dataset_dicts, copy=False),
        DatasetMapper(cfg, True) if mapper is None else mapper,
    )

    sampler_name = cfg.DATALOADER.SAMPLER_TRAIN
    logging.getLogger(__name__).info("Using training sampler {}".format(sampler_name))
    if sampler_name == "TrainingSampler":
        sampler = samplers.TrainingSampler(len(dataset))
    elif sampler_name == "RepeatFactorTrainingSampler":
        sampler = samplers.RepeatFactorTrainingSampler(
            dataset_dicts, cfg.DATALOADER.REPEAT_THRESHOLD
        )
    else:
        raise ValueError("Unknown training sampler: {}".format(sampler_name))

    if not cfg.DATALOADER.ASPECT_RATIO_GROUPING:
        # drop_last so the batch always have the same size
        fixed_size_batches = torch.utils.data.sampler.BatchSampler(
            sampler, per_process_batch_size, drop_last=True
        )
        return torch.utils.data.DataLoader(
            dataset,
            num_workers=cfg.DATALOADER.NUM_WORKERS,
            batch_sampler=fixed_size_batches,
            collate_fn=trivial_batch_collator,
            worker_init_fn=worker_init_reset_seed,
        )

    # Yield individual mapped dicts and hand the batching over to
    # AspectRatioGroupedDataset.
    single_item_loader = torch.utils.data.DataLoader(
        dataset,
        sampler=sampler,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=None,
        collate_fn=operator.itemgetter(0),  # don't batch, but yield individual elements
        worker_init_fn=worker_init_reset_seed,
    )
    return AspectRatioGroupedDataset(single_item_loader, per_process_batch_size)
def build_detection_test_loader(cfg, dataset_name, mapper=None):
    """
    Build a data loader for evaluating on a single dataset.

    Similar to `build_detection_train_loader`.
    But this function uses the given `dataset_name` argument (instead of the names in cfg),
    and uses batch size 1.

    Args:
        cfg: a detectron2 CfgNode
        dataset_name (str): a name of the dataset that's available in the DatasetCatalog
        mapper (callable): a callable which takes a sample (dict) from dataset
            and returns the format to be consumed by the model.
            By default it will be `DatasetMapper(cfg, False)`.

    Returns:
        DataLoader: a torch DataLoader, that loads the given detection
        dataset, with test-time transformation and batching.
    """
    if cfg.MODEL.LOAD_PROPOSALS:
        # The proposal file is looked up at the same position that
        # `dataset_name` has in cfg.DATASETS.TEST.
        position = list(cfg.DATASETS.TEST).index(dataset_name)
        proposal_files = [cfg.DATASETS.PROPOSAL_FILES_TEST[position]]
    else:
        proposal_files = None

    dataset_dicts = get_detection_dataset_dicts(
        [dataset_name], filter_empty=False, proposal_files=proposal_files
    )
    dataset = MapDataset(
        DatasetFromList(dataset_dicts),
        DatasetMapper(cfg, False) if mapper is None else mapper,
    )

    # Always use 1 image per worker during inference since this is the
    # standard when reporting inference time in papers.
    one_image_batches = torch.utils.data.sampler.BatchSampler(
        samplers.InferenceSampler(len(dataset)), 1, drop_last=False
    )
    return torch.utils.data.DataLoader(
        dataset,
        num_workers=cfg.DATALOADER.NUM_WORKERS,
        batch_sampler=one_image_batches,
        collate_fn=trivial_batch_collator,
    )
def trivial_batch_collator(batch):
    """
    Identity collate function: hand the batch back exactly as received.

    Args:
        batch: the batch passed in by the data loader.

    Returns:
        the very same `batch` object, unmodified.
    """
    return batch
def worker_init_reset_seed(worker_id):
    """
    Re-seed the random number generators of a data loader worker.

    Args:
        worker_id (int): id of the worker; it is added to a randomly drawn
            base seed before calling `seed_all_rng`.
    """
    base_seed = np.random.randint(2 ** 31)
    seed_all_rng(base_seed + worker_id)