|
""" |
|
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved |
|
|
|
COCO dataset which returns image_id for evaluation. |
|
Mostly copy-paste from https://github.com/pytorch/vision/blob/13b35ff/references/detection/coco_utils.py |
|
""" |
|
|
|
import torch |
|
import torch.utils.data |
|
|
|
import torchvision |
|
torchvision.disable_beta_transforms_warning() |
|
|
|
from torchvision import datapoints |
|
|
|
from pycocotools import mask as coco_mask |
|
|
|
from src.core import register |
|
|
|
# Public API of this module: only the registered dataset class.
__all__ = ["CocoDetection"]
|
|
|
|
|
@register
class CocoDetection(torchvision.datasets.CocoDetection):
    """COCO detection dataset whose targets also carry the image id.

    Raw COCO annotations returned by ``torchvision.datasets.CocoDetection`` are
    converted by ``ConvertCocoPolysToMask`` into a dict of tensors. The
    ``boxes`` (and optional ``masks``) entries are then wrapped in torchvision
    ``datapoints`` types before the optional ``transforms`` are applied.
    """

    # Attribute lists read by the ``src.core`` registry. Presumably
    # 'transforms' is built from config and injected, and
    # 'remap_mscoco_category' is taken from shared config.
    # NOTE(review): confirm against src.core.register.
    __inject__ = ['transforms']
    __share__ = ['remap_mscoco_category']

    def __init__(self, img_folder, ann_file, transforms, return_masks, remap_mscoco_category=False):
        """
        Args:
            img_folder: directory containing the images.
            ann_file: path to the COCO-format annotation JSON.
            transforms: callable ``(img, target) -> (img, target)`` or ``None``.
            return_masks: if true, rasterize segmentations into ``target['masks']``.
            remap_mscoco_category: if true, map COCO category ids to
                contiguous labels via ``mscoco_category2label``.
        """
        # The parent is built without transforms; they are applied in
        # __getitem__ after the target has been converted to tensors.
        super().__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks, remap_mscoco_category)
        self.img_folder = img_folder
        self.ann_file = ann_file
        self.return_masks = return_masks
        self.remap_mscoco_category = remap_mscoco_category

    def __getitem__(self, idx):
        """Return ``(img, target)`` for the ``idx``-th entry of ``self.ids``."""
        img, annotations = super().__getitem__(idx)
        image_id = self.ids[idx]
        target = {'image_id': image_id, 'annotations': annotations}
        img, target = self.prepare(img, target)

        if 'boxes' in target:
            # ``img.size`` is (w, h), the same unpacking ConvertCocoPolysToMask
            # uses. Reversed, it gives the (h, w) expected by ``spatial_size``.
            target['boxes'] = datapoints.BoundingBox(
                target['boxes'],
                format=datapoints.BoundingBoxFormat.XYXY,
                spatial_size=img.size[::-1])

        if 'masks' in target:
            target['masks'] = datapoints.Mask(target['masks'])

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        return img, target

    def extra_repr(self) -> str:
        """Describe the dataset configuration for ``repr()``."""
        s = f' img_folder: {self.img_folder}\n ann_file: {self.ann_file}\n'
        s += f' return_masks: {self.return_masks}\n'
        # Reported because it changes which label ids the targets contain.
        s += f' remap_mscoco_category: {self.remap_mscoco_category}\n'
        if hasattr(self, '_transforms') and self._transforms is not None:
            s += f' transforms:\n {repr(self._transforms)}'

        return s
|
|
|
|
|
def convert_coco_poly_to_mask(segmentations, height, width):
    """Rasterize one COCO segmentation per object into a stacked mask tensor.

    Args:
        segmentations: iterable with one segmentation per object, each in a
            form accepted by ``pycocotools.mask.frPyObjects`` (e.g. a list of
            polygons).
        height: image height in pixels.
        width: image width in pixels.

    Returns:
        A uint8 tensor of shape ``(num_objects, height, width)``. When
        ``segmentations`` is empty, the result has shape ``(0, height, width)``.
    """
    def _rasterize(polygons):
        # Decode this object's parts and union them into a single (h, w) mask.
        decoded = coco_mask.decode(coco_mask.frPyObjects(polygons, height, width))
        if decoded.ndim < 3:
            # A 2-D result gets a trailing axis so the reduction below works.
            decoded = decoded[..., None]
        # ``any`` on a uint8 tensor keeps the uint8 dtype.
        return torch.as_tensor(decoded, dtype=torch.uint8).any(dim=2)

    per_object = [_rasterize(polygons) for polygons in segmentations]
    if not per_object:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_object, dim=0)
|
|
|
|
|
class ConvertCocoPolysToMask(object):
    """Turn a raw COCO annotation list into a dict of detection target tensors.

    Invoked as ``converter(image, {'image_id': ..., 'annotations': [...]})``.
    It returns ``(image, target)``: the image is passed through unchanged, and
    the input target dict is replaced by a new one.
    """

    def __init__(self, return_masks=False, remap_mscoco_category=False):
        # When true, rasterize each object's 'segmentation' into target['masks'].
        self.return_masks = return_masks
        # When true, translate 'category_id' through mscoco_category2label.
        self.remap_mscoco_category = remap_mscoco_category

    def __call__(self, image, target):
        width, height = image.size
        image_id = torch.tensor([target["image_id"]])

        # Crowd annotations are excluded. A missing 'iscrowd' key counts as 0.
        objs = [o for o in target["annotations"] if o.get("iscrowd", 0) == 0]

        # COCO boxes are [x, y, w, h]. Convert them to [x0, y0, x1, y1] and
        # clip the corners to the image bounds.
        boxes = torch.as_tensor([o["bbox"] for o in objs], dtype=torch.float32).reshape(-1, 4)
        boxes[:, 2:] += boxes[:, :2]
        boxes[:, 0::2].clamp_(min=0, max=width)
        boxes[:, 1::2].clamp_(min=0, max=height)

        category_ids = [o["category_id"] for o in objs]
        if self.remap_mscoco_category:
            category_ids = [mscoco_category2label[c] for c in category_ids]
        labels = torch.tensor(category_ids, dtype=torch.int64)

        masks = None
        if self.return_masks:
            masks = convert_coco_poly_to_mask([o["segmentation"] for o in objs], height, width)

        keypoints = None
        if objs and "keypoints" in objs[0]:
            keypoints = torch.as_tensor([o["keypoints"] for o in objs], dtype=torch.float32)
            count = keypoints.shape[0]
            if count:
                # Flat [x1, y1, v1, x2, y2, v2, ...] -> (num_objects, K, 3).
                keypoints = keypoints.view(count, -1, 3)

        # Boxes that collapsed to zero width or height after clipping are dropped.
        keep = (boxes[:, 2] > boxes[:, 0]) & (boxes[:, 3] > boxes[:, 1])

        out = {"boxes": boxes[keep], "labels": labels[keep]}
        if self.return_masks:
            out["masks"] = masks[keep]
        out["image_id"] = image_id
        if keypoints is not None:
            out["keypoints"] = keypoints[keep]

        out["area"] = torch.tensor([o["area"] for o in objs])[keep]
        out["iscrowd"] = torch.tensor([o.get("iscrowd", 0) for o in objs])[keep]

        # Both sizes are stored in (w, h) order.
        out["orig_size"] = torch.as_tensor([int(width), int(height)])
        out["size"] = torch.as_tensor([int(width), int(height)])

        return image, out
|
|
|
|
|
# COCO category id -> class name. The ids are sparse (80 classes spread over
# 1..90) and are listed in ascending order, which fixes the label order below.
mscoco_category2name = {
    1: "person",
    2: "bicycle",
    3: "car",
    4: "motorcycle",
    5: "airplane",
    6: "bus",
    7: "train",
    8: "truck",
    9: "boat",
    10: "traffic light",
    11: "fire hydrant",
    13: "stop sign",
    14: "parking meter",
    15: "bench",
    16: "bird",
    17: "cat",
    18: "dog",
    19: "horse",
    20: "sheep",
    21: "cow",
    22: "elephant",
    23: "bear",
    24: "zebra",
    25: "giraffe",
    27: "backpack",
    28: "umbrella",
    31: "handbag",
    32: "tie",
    33: "suitcase",
    34: "frisbee",
    35: "skis",
    36: "snowboard",
    37: "sports ball",
    38: "kite",
    39: "baseball bat",
    40: "baseball glove",
    41: "skateboard",
    42: "surfboard",
    43: "tennis racket",
    44: "bottle",
    46: "wine glass",
    47: "cup",
    48: "fork",
    49: "knife",
    50: "spoon",
    51: "bowl",
    52: "banana",
    53: "apple",
    54: "sandwich",
    55: "orange",
    56: "broccoli",
    57: "carrot",
    58: "hot dog",
    59: "pizza",
    60: "donut",
    61: "cake",
    62: "chair",
    63: "couch",
    64: "potted plant",
    65: "bed",
    67: "dining table",
    70: "toilet",
    72: "tv",
    73: "laptop",
    74: "mouse",
    75: "remote",
    76: "keyboard",
    77: "cell phone",
    78: "microwave",
    79: "oven",
    80: "toaster",
    81: "sink",
    82: "refrigerator",
    84: "book",
    85: "clock",
    86: "vase",
    87: "scissors",
    88: "teddy bear",
    89: "hair drier",
    90: "toothbrush",
}

# COCO category id -> contiguous label 0..79, assigned in the dict's key order.
mscoco_category2label = {cat_id: label for label, cat_id in enumerate(mscoco_category2name)}

# Inverse of mscoco_category2label: contiguous label -> COCO category id.
mscoco_label2category = dict(enumerate(mscoco_category2name))