# Uploaded by UniversalAlgorithmic ("Upload 185 files", commit 2215b89, verified).
# File size: 3.84 kB.
import copy
import os
import torch
import torch.utils.data
import torchvision
from PIL import Image
from pycocotools import mask as coco_mask
from transforms import Compose
class FilterAndRemapCocoCategories:
    """Transform that drops annotations outside a category whitelist.

    When ``remap`` is true, each surviving annotation's ``"category_id"`` is
    replaced by the position of that id inside ``categories``.
    """

    def __init__(self, categories, remap=True):
        """Store the whitelist and whether ids should be rewritten.

        Args:
            categories: Sequence of accepted category ids; its ordering
                defines the remapped ids.
            remap: If true, rewrite ids to their index in ``categories``.
        """
        self.categories = categories
        self.remap = remap

    def __call__(self, image, anno):
        """Filter (and optionally remap) the annotations of one sample.

        Args:
            image: Passed through untouched.
            anno: List of annotation dicts, each with a ``"category_id"`` key.

        Returns:
            ``(image, filtered_anno)``. With remapping enabled the annotation
            list is a deep copy, so the caller's dicts are never modified.
        """
        kept = [entry for entry in anno if entry["category_id"] in self.categories]
        if self.remap:
            # Work on a copy so the rewrite below does not leak into the input.
            kept = copy.deepcopy(kept)
            for entry in kept:
                entry["category_id"] = self.categories.index(entry["category_id"])
        return image, kept
def convert_coco_poly_to_mask(segmentations, height, width):
    """Turn per-instance polygon segmentations into a stacked mask tensor.

    Args:
        segmentations: Iterable with one polygon collection per instance,
            each handed to ``coco_mask.frPyObjects``.
        height: Mask height passed to the encoder.
        width: Mask width passed to the encoder.

    Returns:
        A ``torch.uint8`` tensor: the per-instance masks stacked along dim 0,
        or an empty ``(0, height, width)`` tensor when there are no instances.
    """

    def _instance_mask(polygons):
        # Encode the polygons, then decode them back into a dense array.
        encoded = coco_mask.frPyObjects(polygons, height, width)
        dense = coco_mask.decode(encoded)
        if dense.ndim < 3:
            # Make sure there is a trailing axis to reduce over.
            dense = dense[..., None]
        # Collapse the trailing axis: a pixel is set if any part covers it.
        return torch.as_tensor(dense, dtype=torch.uint8).any(dim=2)

    per_instance = [_instance_mask(polygons) for polygons in segmentations]
    if not per_instance:
        return torch.zeros((0, height, width), dtype=torch.uint8)
    return torch.stack(per_instance, dim=0)
class ConvertCocoPolysToMask:
    """Transform that turns COCO polygon annotations into a label image."""

    def __call__(self, image, anno):
        """Build a single-channel segmentation target for one sample.

        Args:
            image: Object whose ``size`` attribute unpacks as ``(width, height)``.
            anno: List of annotation dicts with ``"segmentation"`` and
                ``"category_id"`` keys.

        Returns:
            ``(image, target)`` where ``target`` is a PIL image built from the
            label map. Pixels covered by more than one instance are set to
            255; with no annotations the map is all zeros.
        """
        width, height = image.size
        polygons_per_obj = [obj["segmentation"] for obj in anno]
        category_ids = [obj["category_id"] for obj in anno]

        if not polygons_per_obj:
            label_map = torch.zeros((height, width), dtype=torch.uint8)
        else:
            instance_masks = convert_coco_poly_to_mask(polygons_per_obj, height, width)
            labels = torch.as_tensor(category_ids, dtype=instance_masks.dtype)
            # Scale every instance mask by its category id, then take the
            # per-pixel maximum to merge them into a single label map.
            weighted = instance_masks * labels[:, None, None]
            label_map = weighted.max(dim=0)[0]
            # Pixels claimed by several instances are ambiguous: mark them 255.
            overlapping = instance_masks.sum(0) > 1
            label_map[overlapping] = 255

        return image, Image.fromarray(label_map.numpy())
def _coco_remove_images_without_annotations(dataset, cat_list=None):
    """Restrict a COCO detection dataset to images with enough annotated area.

    Args:
        dataset: A ``torchvision.datasets.CocoDetection`` instance.
        cat_list: Optional collection of category ids; when truthy, only
            annotations with these ids are counted.

    Returns:
        A ``torch.utils.data.Subset`` of ``dataset`` holding the positions of
        the images whose counted annotations total more than 1000 in ``"area"``.

    Raises:
        TypeError: If ``dataset`` is not a ``CocoDetection``.
    """
    if not isinstance(dataset, torchvision.datasets.CocoDetection):
        raise TypeError(
            f"This function expects dataset of type torchvision.datasets.CocoDetection, instead got {type(dataset)}"
        )

    def _is_kept(img_id):
        # Load every annotation of the image (iscrowd=None applies no crowd filter).
        annotation_ids = dataset.coco.getAnnIds(imgIds=img_id, iscrowd=None)
        records = dataset.coco.loadAnns(annotation_ids)
        if cat_list:
            records = [rec for rec in records if rec["category_id"] in cat_list]
        # Needs at least one annotation and more than 1k of summed area.
        return len(records) != 0 and sum(rec["area"] for rec in records) > 1000

    kept_positions = [pos for pos, img_id in enumerate(dataset.ids) if _is_kept(img_id)]
    return torch.utils.data.Subset(dataset, kept_positions)
def get_coco(root, image_set, transforms):
    """Create the COCO segmentation dataset for the requested split.

    Args:
        root: Directory containing the image folders and ``annotations``.
        image_set: Split key, ``"train"`` or ``"val"``.
        transforms: Transform appended after category filtering and
            polygon-to-mask conversion.

    Returns:
        A ``torchvision.datasets.CocoDetection`` for ``"val"``; for
        ``"train"`` the result of ``_coco_remove_images_without_annotations``
        applied to that dataset.

    Raises:
        KeyError: If ``image_set`` is not a known split key.
    """
    # NOTE: for quick debugging, "train" can be pointed at the val2017 entry
    # ("val2017", annotations/instances_val2017.json) instead.
    split_files = {
        "train": ("train2017", os.path.join("annotations", "instances_train2017.json")),
        "val": ("val2017", os.path.join("annotations", "instances_val2017.json")),
    }
    # Category ids to keep; the list position becomes the remapped class id.
    kept_categories = [0, 5, 2, 16, 9, 44, 6, 3, 17, 62, 21, 67, 18, 19, 4, 1, 64, 20, 63, 7, 72]

    pipeline = Compose(
        [
            FilterAndRemapCocoCategories(kept_categories, remap=True),
            ConvertCocoPolysToMask(),
            transforms,
        ]
    )

    folder_name, annotation_rel_path = split_files[image_set]
    dataset = torchvision.datasets.CocoDetection(
        os.path.join(root, folder_name),
        os.path.join(root, annotation_rel_path),
        transforms=pipeline,
    )

    if image_set != "train":
        return dataset
    # Only the training split is pruned.
    return _coco_remove_images_without_annotations(dataset, kept_categories)