Spaces:
Sleeping
Sleeping
| """ Pascal VOC dataset parser | |
| Copyright 2020 Ross Wightman | |
| """ | |
| import os | |
| import xml.etree.ElementTree as ET | |
| from collections import defaultdict | |
| import numpy as np | |
| from .parser import Parser | |
| from .parser_config import VocParserCfg | |
| class VocParser(Parser): | |
| DEFAULT_CLASSES = ( | |
| 'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair', | |
| 'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant', | |
| 'sheep', 'sofa', 'train', 'tvmonitor') | |
| def __init__(self, cfg: VocParserCfg): | |
| super().__init__( | |
| bbox_yxyx=cfg.bbox_yxyx, | |
| has_labels=cfg.has_labels, | |
| include_masks=False, # FIXME to support someday | |
| include_bboxes_ignore=False, | |
| ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt, | |
| min_img_size=cfg.min_img_size | |
| ) | |
| self.correct_bbox = 1 | |
| self.keep_difficult = cfg.keep_difficult | |
| self.anns = None | |
| self.img_id_to_idx = {} | |
| self._load_annotations( | |
| split_filename=cfg.split_filename, | |
| img_filename=cfg.img_filename, | |
| ann_filename=cfg.ann_filename, | |
| classes=cfg.classes, | |
| ) | |
| def _load_annotations( | |
| self, | |
| split_filename: str, | |
| img_filename: str, | |
| ann_filename: str, | |
| classes=None, | |
| ): | |
| classes = classes or self.DEFAULT_CLASSES | |
| self.cat_names = list(classes) | |
| self.cat_ids = self.cat_names | |
| self.cat_id_to_label = {cat: i + self.label_offset for i, cat in enumerate(self.cat_ids)} | |
| self.anns = [] | |
| with open(split_filename) as f: | |
| ids = f.readlines() | |
| for img_id in ids: | |
| img_id = img_id.strip("\n") | |
| filename = img_filename % img_id | |
| xml_path = ann_filename % img_id | |
| tree = ET.parse(xml_path) | |
| root = tree.getroot() | |
| size = root.find('size') | |
| width = int(size.find('width').text) | |
| height = int(size.find('height').text) | |
| if min(width, height) < self.min_img_size: | |
| continue | |
| anns = [] | |
| for obj_idx, obj in enumerate(root.findall('object')): | |
| name = obj.find('name').text | |
| label = self.cat_id_to_label[name] | |
| difficult = int(obj.find('difficult').text) | |
| bnd_box = obj.find('bndbox') | |
| bbox = [ | |
| int(bnd_box.find('xmin').text), | |
| int(bnd_box.find('ymin').text), | |
| int(bnd_box.find('xmax').text), | |
| int(bnd_box.find('ymax').text) | |
| ] | |
| anns.append(dict(label=label, bbox=bbox, difficult=difficult)) | |
| if not self.ignore_empty_gt or len(anns): | |
| self.anns.append(anns) | |
| self.img_infos.append(dict(id=img_id, file_name=filename, width=width, height=height)) | |
| self.img_ids.append(img_id) | |
| else: | |
| self.img_ids_invalid.append(img_id) | |
| def merge(self, other): | |
| assert len(self.cat_ids) == len(other.cat_ids) | |
| self.img_ids.extend(other.img_ids) | |
| self.img_infos.extend(other.img_infos) | |
| self.anns.extend(other.anns) | |
| def get_ann_info(self, idx): | |
| return self._parse_ann_info(self.anns[idx]) | |
| def _parse_ann_info(self, ann_info): | |
| bboxes = [] | |
| labels = [] | |
| bboxes_ignore = [] | |
| labels_ignore = [] | |
| for ann in ann_info: | |
| ignore = False | |
| x1, y1, x2, y2 = ann['bbox'] | |
| label = ann['label'] | |
| w = x2 - x1 | |
| h = y2 - y1 | |
| if w < 1 or h < 1: | |
| ignore = True | |
| if self.yxyx: | |
| bbox = [y1, x1, y2, x2] | |
| else: | |
| bbox = ann['bbox'] | |
| if ignore or (ann['difficult'] and not self.keep_difficult): | |
| bboxes_ignore.append(bbox) | |
| labels_ignore.append(label) | |
| else: | |
| bboxes.append(bbox) | |
| labels.append(label) | |
| if not bboxes: | |
| bboxes = np.zeros((0, 4), dtype=np.float32) | |
| labels = np.zeros((0, ), dtype=np.float32) | |
| else: | |
| bboxes = np.array(bboxes, ndmin=2, dtype=np.float32) - self.correct_bbox | |
| labels = np.array(labels, dtype=np.float32) | |
| if self.include_bboxes_ignore: | |
| if not bboxes_ignore: | |
| bboxes_ignore = np.zeros((0, 4), dtype=np.float32) | |
| labels_ignore = np.zeros((0, ), dtype=np.float32) | |
| else: | |
| bboxes_ignore = np.array(bboxes_ignore, ndmin=2, dtype=np.float32) - self.correct_bbox | |
| labels_ignore = np.array(labels_ignore, dtype=np.float32) | |
| ann = dict( | |
| bbox=bboxes.astype(np.float32), | |
| cls=labels.astype(np.int64)) | |
| if self.include_bboxes_ignore: | |
| ann.update(dict( | |
| bbox_ignore=bboxes_ignore.astype(np.float32), | |
| cls_ignore=labels_ignore.astype(np.int64))) | |
| return ann | |