File size: 5,140 Bytes
fa84113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
""" Pascal VOC dataset parser

Copyright 2020 Ross Wightman
"""
import os
import xml.etree.ElementTree as ET
from collections import defaultdict
import numpy as np

from .parser import Parser
from .parser_config import VocParserCfg


class VocParser(Parser):
    """Parser for Pascal VOC style (one XML file per image) detection annotations.

    A split file lists one image id per line. For every id the matching XML
    annotation is read and the objects it contains are stored in memory as dicts
    with `label`, `bbox` (xmin, ymin, xmax, ymax as read from the XML) and
    `difficult` entries. `get_ann_info` converts them to numpy arrays on demand.
    """

    DEFAULT_CLASSES = (
        'aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car', 'cat', 'chair',
        'cow', 'diningtable', 'dog', 'horse', 'motorbike', 'person', 'pottedplant',
        'sheep', 'sofa', 'train', 'tvmonitor')

    def __init__(self, cfg: VocParserCfg):
        """Build the parser and eagerly load every annotation listed in the split file.

        Args:
            cfg: parser config. The fields read here are `bbox_yxyx`, `has_labels`,
                `ignore_empty_gt`, `min_img_size`, `keep_difficult`, `split_filename`,
                `img_filename`, `ann_filename` and `classes`.
        """
        super().__init__(
            bbox_yxyx=cfg.bbox_yxyx,
            has_labels=cfg.has_labels,
            include_masks=False,  # FIXME to support someday
            include_bboxes_ignore=False,
            ignore_empty_gt=cfg.has_labels and cfg.ignore_empty_gt,
            min_img_size=cfg.min_img_size
        )
        # Subtracted from every box coordinate when building the output arrays.
        # NOTE(review): presumably maps VOC's 1-based pixel coords to 0-based -- confirm.
        self.correct_bbox = 1
        self.keep_difficult = cfg.keep_difficult

        self.anns = None
        self.img_id_to_idx = {}  # not populated anywhere in this class
        self._load_annotations(
            split_filename=cfg.split_filename,
            img_filename=cfg.img_filename,
            ann_filename=cfg.ann_filename,
            classes=cfg.classes,
        )

    @staticmethod
    def _parse_int(text):
        """Convert XML element text to int, tolerating float formatted values like '45.7'.

        Plain integer text is parsed exactly as `int(text)`; anything else is parsed
        as a float and truncated toward zero.
        """
        try:
            return int(text)
        except ValueError:
            return int(float(text))

    def _parse_object(self, obj):
        """Turn one `<object>` XML element into an annotation dict.

        Raises:
            KeyError: if the object's name is not one of the configured classes.
        """
        label = self.cat_id_to_label[obj.find('name').text]
        # The `difficult` tag is optional in many VOC style exports; a missing or
        # empty tag is treated as "not difficult" instead of crashing.
        difficult_text = (obj.findtext('difficult') or '').strip()
        difficult = self._parse_int(difficult_text) if difficult_text else 0
        bnd_box = obj.find('bndbox')
        bbox = [
            self._parse_int(bnd_box.find(tag).text)
            for tag in ('xmin', 'ymin', 'xmax', 'ymax')
        ]
        return dict(label=label, bbox=bbox, difficult=difficult)

    def _load_annotations(
            self,
            split_filename: str,
            img_filename: str,
            ann_filename: str,
            classes=None,
    ):
        """Read the split file and parse the XML annotation of every listed image.

        Args:
            split_filename: path of a text file with one image id per line.
            img_filename: %-format pattern, expanded with the image id to get the image filename.
            ann_filename: %-format pattern, expanded with the image id to get the XML path.
            classes: category names in label order, falls back to `DEFAULT_CLASSES` when empty.

        Images smaller than `self.min_img_size` on their shortest side are skipped.
        When `self.ignore_empty_gt` is set, images without objects are recorded in
        `self.img_ids_invalid` instead of being added to the dataset.
        """
        classes = classes or self.DEFAULT_CLASSES
        self.cat_names = list(classes)
        self.cat_ids = self.cat_names
        self.cat_id_to_label = {cat: i + self.label_offset for i, cat in enumerate(self.cat_ids)}

        self.anns = []

        with open(split_filename) as f:
            # Strip all surrounding whitespace (handles CRLF files) and drop blank
            # lines so a trailing empty line does not become a bogus image id.
            ids = [stripped for stripped in (line.strip() for line in f) if stripped]

        for img_id in ids:
            filename = img_filename % img_id
            root = ET.parse(ann_filename % img_id).getroot()
            size = root.find('size')
            width = self._parse_int(size.find('width').text)
            height = self._parse_int(size.find('height').text)
            if min(width, height) < self.min_img_size:
                continue

            anns = [self._parse_object(obj) for obj in root.findall('object')]

            if anns or not self.ignore_empty_gt:
                self.anns.append(anns)
                self.img_infos.append(dict(id=img_id, file_name=filename, width=width, height=height))
                self.img_ids.append(img_id)
            else:
                self.img_ids_invalid.append(img_id)

    def merge(self, other):
        """Append the images and annotations of `other` to this parser, in place.

        Only the number of categories is checked for compatibility, not their names.
        """
        assert len(self.cat_ids) == len(other.cat_ids)
        self.img_ids.extend(other.img_ids)
        self.img_infos.extend(other.img_infos)
        self.anns.extend(other.anns)

    def get_ann_info(self, idx):
        """Return the array form annotations (see `_parse_ann_info`) of the image at `idx`."""
        return self._parse_ann_info(self.anns[idx])

    def _boxes_to_array(self, boxes):
        """Convert a list of 4-element boxes to a float32 (N, 4) array shifted by `correct_bbox`.

        `reshape(-1, 4)` keeps the (0, 4) shape for an empty list.
        """
        return np.array(boxes, dtype=np.float32).reshape(-1, 4) - self.correct_bbox

    def _parse_ann_info(self, ann_info):
        """Convert the stored annotation dicts of one image to numpy arrays.

        Boxes narrower or shorter than 1 pixel, and boxes flagged difficult while
        `self.keep_difficult` is false, go to the ignore lists rather than the
        main outputs.

        Args:
            ann_info: list of dicts with `bbox`, `label` and `difficult` keys.

        Returns:
            dict with `bbox` (float32, shape (N, 4), in yxyx order when `self.yxyx`
            is set, xyxy otherwise) and `cls` (int64, shape (N,)). When
            `self.include_bboxes_ignore` is set, `bbox_ignore` and `cls_ignore`
            are included with the same layout.
        """
        bboxes, labels = [], []
        bboxes_ignore, labels_ignore = [], []
        for ann in ann_info:
            x1, y1, x2, y2 = ann['bbox']
            # The size check uses the raw coords, before the `correct_bbox` shift.
            degenerate = (x2 - x1) < 1 or (y2 - y1) < 1
            bbox = [y1, x1, y2, x2] if self.yxyx else [x1, y1, x2, y2]
            if degenerate or (ann['difficult'] and not self.keep_difficult):
                bboxes_ignore.append(bbox)
                labels_ignore.append(ann['label'])
            else:
                bboxes.append(bbox)
                labels.append(ann['label'])

        ann = dict(
            bbox=self._boxes_to_array(bboxes).astype(np.float32),
            cls=np.array(labels, dtype=np.int64))

        if self.include_bboxes_ignore:
            ann.update(dict(
                bbox_ignore=self._boxes_to_array(bboxes_ignore).astype(np.float32),
                cls_ignore=np.array(labels_ignore, dtype=np.int64)))
        return ann