""" Pascal VOC dataset parser

Copyright 2020 Ross Wightman
"""
import os
import xml.etree.ElementTree as ET
from collections import defaultdict

import numpy as np

from .parser import Parser
from .parser_config import VocParserCfg


class VocParser(Parser):
    """Parser for Pascal VOC style datasets.

    A split file lists one image id per line. For every id, the image path and the
    annotation XML path are built by ``%``-formatting ``cfg.img_filename`` and
    ``cfg.ann_filename`` with that id. All annotations are parsed eagerly in ``__init__``.
    """

    DEFAULT_CLASSES = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus',
        'car', 'cat', 'chair', 'cow', 'diningtable', 'dog',
        'horse', 'motorbike', 'person', 'pottedplant', 'sheep',
        'sofa', 'train', 'tvmonitor')

    def __init__(self, cfg: VocParserCfg):
        super().__init__(
            bbox_yxyx=cfg.bbox_yxyx,
            has_labels=cfg.has_labels,
            include_masks=False,  # FIXME to support someday
            include_bboxes_ignore=False,
            ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt,
            min_img_size=cfg.min_img_size
        )
        # Offset subtracted from every parsed box coordinate in _parse_ann_info
        # (presumably converts VOC's 1-based pixel indices to 0-based -- confirm).
        self.correct_bbox = 1
        # When False, objects flagged <difficult>1</difficult> go to the ignore set.
        self.keep_difficult = cfg.keep_difficult
        # Per-image list of annotation dicts, parallel to self.img_ids / self.img_infos.
        self.anns = None
        self.img_id_to_idx = {}
        self._load_annotations(
            split_filename=cfg.split_filename,
            img_filename=cfg.img_filename,
            ann_filename=cfg.ann_filename,
            classes=cfg.classes,
        )

    def _load_annotations(
            self,
            split_filename: str,
            img_filename: str,
            ann_filename: str,
            classes=None,
    ):
        """Read the split file and parse the XML annotation of every listed image.

        Args:
            split_filename: text file with one image id per line.
            img_filename: ``%``-format template mapping an image id to an image path.
            ann_filename: ``%``-format template mapping an image id to an XML path.
            classes: ordered class names; falls back to ``DEFAULT_CLASSES`` when falsy.
        """
        classes = classes or self.DEFAULT_CLASSES
        self.cat_names = list(classes)
        self.cat_ids = self.cat_names
        self.cat_id_to_label = {cat: i + self.label_offset for i, cat in enumerate(self.cat_ids)}

        self.anns = []
        with open(split_filename) as f:
            # strip() rather than strip('\n'): CRLF split files would otherwise leave
            # a trailing '\r' in every id. Blank lines are dropped instead of
            # producing a bogus '' id.
            ids = [line.strip() for line in f]
        ids = [img_id for img_id in ids if img_id]

        for img_id in ids:
            filename = img_filename % img_id
            xml_path = ann_filename % img_id
            root = ET.parse(xml_path).getroot()

            size = root.find('size')
            width = self._parse_int(size, 'width')
            height = self._parse_int(size, 'height')
            # Images below the minimum size are dropped entirely; note they are not
            # recorded in img_ids_invalid.
            if min(width, height) < self.min_img_size:
                continue

            anns = self._parse_objects(root)
            if not self.ignore_empty_gt or anns:
                self.anns.append(anns)
                self.img_infos.append(dict(id=img_id, file_name=filename, width=width, height=height))
                self.img_ids.append(img_id)
            else:
                self.img_ids_invalid.append(img_id)

    def _parse_objects(self, root):
        """Return one ``dict(label=..., bbox=[xmin, ymin, xmax, ymax], difficult=...)`` per <object>.

        Raises:
            KeyError: if an object's <name> is not one of the configured classes.
        """
        anns = []
        for obj in root.findall('object'):
            label = self.cat_id_to_label[obj.find('name').text]
            # Some VOC-format annotations omit <difficult>; treat a missing tag as 0.
            difficult_node = obj.find('difficult')
            difficult = 0 if difficult_node is None else int(difficult_node.text)
            bnd_box = obj.find('bndbox')
            bbox = [self._parse_int(bnd_box, tag) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]
            anns.append(dict(label=label, bbox=bbox, difficult=difficult))
        return anns

    @staticmethod
    def _parse_int(node, tag):
        """Parse the text of child ``tag`` of ``node`` as an int.

        Going through float() also accepts float-formatted values such as '48.0'
        (truncated toward zero); plain integer strings parse exactly as int() would.
        """
        return int(float(node.find(tag).text))

    def merge(self, other):
        """Append the images and annotations of another VocParser to this one.

        Only the number of classes is checked, not their names or order.
        """
        assert len(self.cat_ids) == len(other.cat_ids)
        self.img_ids.extend(other.img_ids)
        self.img_infos.extend(other.img_infos)
        self.anns.extend(other.anns)

    def get_ann_info(self, idx):
        """Return the array-form annotation (see ``_parse_ann_info``) for image index ``idx``."""
        return self._parse_ann_info(self.anns[idx])

    def _parse_ann_info(self, ann_info):
        """Convert one image's annotation dicts into numpy arrays.

        Boxes narrower or shorter than 1 unit, and difficult boxes (unless
        ``keep_difficult``), are routed to the ignore set. Ignored boxes are only
        returned when ``include_bboxes_ignore`` is set.

        Returns:
            dict with ``bbox`` (float32, (N, 4), in yxyx order if ``self.yxyx`` else
            xyxy, offset by ``-correct_bbox``) and ``cls`` (int64, (N,)), plus
            ``bbox_ignore`` / ``cls_ignore`` in the same format when enabled.
        """
        bboxes, labels = [], []
        bboxes_ignore, labels_ignore = [], []
        for ann in ann_info:
            x1, y1, x2, y2 = ann['bbox']
            label = ann['label']
            degenerate = (x2 - x1) < 1 or (y2 - y1) < 1
            bbox = [y1, x1, y2, x2] if self.yxyx else ann['bbox']
            if degenerate or (ann['difficult'] and not self.keep_difficult):
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)

        bbox_arr, cls_arr = self._to_arrays(bboxes, labels)
        ann = dict(bbox=bbox_arr, cls=cls_arr)
        if self.include_bboxes_ignore:
            bbox_ignore_arr, cls_ignore_arr = self._to_arrays(bboxes_ignore, labels_ignore)
            ann.update(dict(bbox_ignore=bbox_ignore_arr, cls_ignore=cls_ignore_arr))
        return ann

    def _to_arrays(self, bboxes, labels):
        """Convert box/label lists into (float32 (N, 4), int64 (N,)) arrays, applying correct_bbox."""
        if not bboxes:
            return np.zeros((0, 4), dtype=np.float32), np.zeros((0,), dtype=np.int64)
        bbox_arr = np.array(bboxes, ndmin=2, dtype=np.float32) - self.correct_bbox
        return bbox_arr.astype(np.float32), np.array(labels, dtype=np.int64)