# NOTE: scraped page header (not part of the module): Spaces: Sleeping; File size: 5,140 Bytes; revision fa84113.
""" Pascal VOC dataset parser
Copyright 2020 Ross Wightman
"""
import os
import xml.etree.ElementTree as ET
from collections import defaultdict
import numpy as np
from .parser import Parser
from .parser_config import VocParserCfg
class VocParser(Parser):
    """Parser for Pascal VOC style XML detection annotations.

    On construction the parser reads a split file listing image ids (one per
    line), loads the XML annotation of every listed image and keeps the
    per-image object annotations in memory. ``get_ann_info`` then converts the
    annotations of one image into numpy arrays.
    """

    DEFAULT_CLASSES = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
        'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant',
        'sheep', 'sofa', 'train', 'tvmonitor')

    def __init__(self, cfg: VocParserCfg):
        """Build the parser from ``cfg`` and eagerly load all annotations.

        Args:
            cfg: Parser configuration. The fields read here are ``bbox_yxyx``,
                ``has_labels``, ``ignore_empty_gt``, ``min_img_size``,
                ``keep_difficult``, ``split_filename``, ``img_filename``,
                ``ann_filename`` and ``classes``.
        """
        super().__init__(
            bbox_yxyx=cfg.bbox_yxyx,
            has_labels=cfg.has_labels,
            include_masks=False,  # FIXME to support someday
            include_bboxes_ignore=False,
            ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt,
            min_img_size=cfg.min_img_size
        )
        # Offset subtracted from every box coordinate in _parse_ann_info.
        # NOTE(review): presumably converts 1-based VOC pixel coordinates to
        # 0-based ones -- confirm against the dataset specification.
        self.correct_bbox = 1
        self.keep_difficult = cfg.keep_difficult

        self.anns = None
        self.img_id_to_idx = {}
        self._load_annotations(
            split_filename=cfg.split_filename,
            img_filename=cfg.img_filename,
            ann_filename=cfg.ann_filename,
            classes=cfg.classes,
        )

    def _load_annotations(
            self,
            split_filename: str,
            img_filename: str,
            ann_filename: str,
            classes=None,
    ):
        """Read the split file and load the XML annotation of each listed image.

        Args:
            split_filename: Path of a text file with one image id per line.
                Surrounding whitespace is stripped and blank lines are skipped.
            img_filename: ``%``-format template turned into the image file
                name via ``img_filename % img_id``.
            ann_filename: ``%``-format template turned into the XML annotation
                path via ``ann_filename % img_id``.
            classes: Optional sequence of class names; ``DEFAULT_CLASSES`` is
                used when it is ``None`` or empty.

        Raises:
            KeyError: If an object's class name is not in ``classes``.
        """
        classes = classes or self.DEFAULT_CLASSES
        self.cat_names = list(classes)
        self.cat_ids = self.cat_names
        self.cat_id_to_label = {cat: i + self.label_offset for i, cat in enumerate(self.cat_ids)}

        self.anns = []

        with open(split_filename) as f:
            # strip() rather than strip("\n"): also drops '\r' from CRLF files
            # and stray spaces, which would otherwise end up in the file paths.
            img_ids = [line.strip() for line in f]

        for img_id in img_ids:
            if not img_id:
                # Blank line (e.g. trailing newline at end of the split file).
                continue
            filename = img_filename % img_id
            xml_path = ann_filename % img_id
            root = ET.parse(xml_path).getroot()
            size = root.find('size')
            width = int(size.find('width').text)
            height = int(size.find('height').text)
            if min(width, height) < self.min_img_size:
                continue

            anns = []
            for obj in root.findall('object'):
                name = obj.find('name').text
                label = self.cat_id_to_label[name]
                # A missing or empty <difficult> tag is treated as "not difficult".
                difficult = int(obj.findtext('difficult') or 0)
                bbox = self._read_bbox(obj.find('bndbox'))
                anns.append(dict(label=label, bbox=bbox, difficult=difficult))

            if not self.ignore_empty_gt or len(anns):
                self.anns.append(anns)
                self.img_infos.append(dict(id=img_id, file_name=filename, width=width, height=height))
                self.img_ids.append(img_id)
            else:
                self.img_ids_invalid.append(img_id)

    @staticmethod
    def _read_bbox(bnd_box):
        """Return ``[xmin, ymin, xmax, ymax]`` as ints from a ``<bndbox>`` element.

        Values go through ``float`` first so float-formatted coordinates such
        as ``"45.0"`` are accepted (truncated toward zero) instead of raising.
        """
        return [int(float(bnd_box.find(tag).text)) for tag in ('xmin', 'ymin', 'xmax', 'ymax')]

    def merge(self, other):
        """Append the image ids, image infos and annotations of ``other`` to this parser.

        Both parsers must have the same number of categories.
        """
        assert len(self.cat_ids) == len(other.cat_ids), \
            'cannot merge parsers with a different number of categories'
        self.img_ids.extend(other.img_ids)
        self.img_infos.extend(other.img_infos)
        self.anns.extend(other.anns)

    def get_ann_info(self, idx):
        """Return the parsed annotation dict for the image at position ``idx``."""
        return self._parse_ann_info(self.anns[idx])

    def _to_arrays(self, bboxes, labels):
        """Convert box / label lists to ``(float32 (N, 4), int64 (N,))`` arrays.

        ``self.correct_bbox`` is subtracted from every coordinate of a
        non-empty box list.
        """
        if not bboxes:
            return np.zeros((0, 4), dtype=np.float32), np.zeros((0,), dtype=np.int64)
        boxes = np.array(bboxes, ndmin=2, dtype=np.float32) - self.correct_bbox
        return boxes.astype(np.float32), np.array(labels, dtype=np.int64)

    def _parse_ann_info(self, ann_info):
        """Turn one image's object annotations into numpy arrays.

        Args:
            ann_info: List of dicts with keys ``'label'``, ``'bbox'``
                (``[x1, y1, x2, y2]``) and ``'difficult'``, as built by
                ``_load_annotations``.

        Returns:
            Dict with ``'bbox'`` (float32, shape ``(N, 4)``) and ``'cls'``
            (int64, shape ``(N,)``). When ``self.include_bboxes_ignore`` is
            set, ``'bbox_ignore'`` and ``'cls_ignore'`` are added for the
            boxes that were filtered out.
        """
        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for ann in ann_info:
            x1, y1, x2, y2 = ann['bbox']
            label = ann['label']
            # Boxes narrower or shorter than one pixel are never used as targets.
            degenerate = (x2 - x1) < 1 or (y2 - y1) < 1
            bbox = [y1, x1, y2, x2] if self.yxyx else ann['bbox']
            if degenerate or (ann['difficult'] and not self.keep_difficult):
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)

        bbox_arr, cls_arr = self._to_arrays(bboxes, labels)
        ann = dict(bbox=bbox_arr, cls=cls_arr)

        if self.include_bboxes_ignore:
            bbox_ignore_arr, cls_ignore_arr = self._to_arrays(bboxes_ignore, labels_ignore)
            ann.update(dict(bbox_ignore=bbox_ignore_arr, cls_ignore=cls_ignore_arr))
        return ann
|