# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Bounding Box List operations for Numpy BoxLists.

Example box operations that are supported:
  * Areas: compute bounding box areas
  * IOU: pairwise intersection-over-union scores
"""
import numpy as np


class BoxList(object):
    """Box collection.
    BoxList represents a list of bounding boxes as numpy array, where each
    bounding box is represented as a row of 4 numbers,
    [y_min, x_min, y_max, x_max].  It is assumed that all bounding boxes within a
    given list correspond to a single image.
    Optionally, users can add additional related fields (such as
    objectness/classification scores).
    """

    def __init__(self, data):
        """Constructs box collection.
        Args:
          data: a numpy array of shape [N, 4] representing box coordinates
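        Raises:
          ValueError: if bbox data is not a numpy array
          ValueError: if invalid dimensions for bbox data
          ValueError: if bbox data is not of type float32 or float64
          ValueError: if any box has y_min > y_max or x_min > x_max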
        """
        if not isinstance(data, np.ndarray):
            raise ValueError('data must be a numpy array.')
        if len(data.shape) != 2 or data.shape[1] != 4:
            raise ValueError('Invalid dimensions for box data.')
        if data.dtype != np.float32 and data.dtype != np.float64:
            raise ValueError('Invalid data type for box data: float is required.')
        if not self._is_valid_boxes(data):
            raise ValueError('Invalid box data. data must be a numpy array of '
                             'N*[y_min, x_min, y_max, x_max]')
        self.data = {'boxes': data}

    def num_boxes(self):
        """Return number of boxes held in collections."""
        return self.data['boxes'].shape[0]

    def get_extra_fields(self):
        """Return all non-box fields."""
        return [k for k in self.data.keys() if k != 'boxes']

    def has_field(self, field):
        """Return whether the collection holds data for the given field."""
        return field in self.data

    def add_field(self, field, field_data):
        """Add data to a specified field.
        Args:
          field: a string parameter used to specify a related field to be accessed.
          field_data: a numpy array of [N, ...] representing the data associated
              with the field.
        Raises:
          ValueError: if the field already exists or the first dimension of the
              field data does not match the number of boxes.
        """
        if self.has_field(field):
            raise ValueError('Field ' + field + ' already exists')
        if len(field_data.shape) < 1 or field_data.shape[0] != self.num_boxes():
            raise ValueError('Invalid dimensions for field data')
        self.data[field] = field_data

    def get(self):
        """Convenience function for accessing box coordinates.
        Returns:
          a numpy array of shape [N, 4] representing box corners
        """
        return self.get_field('boxes')

    def get_field(self, field):
        """Accesses data associated with the specified field in the box collection.
        Args:
          field: a string parameter used to specify a related field to be accessed.
        Returns:
          a numpy array of shape [N, ...] representing data of the associated field
        Raises:
          ValueError: if invalid field
        """
        if not self.has_field(field):
            raise ValueError('field {} does not exist'.format(field))
        return self.data[field]

    def get_coordinates(self):
        """Get corner coordinates of boxes.
        Returns:
         a list of 4 1-d numpy arrays [y_min, x_min, y_max, x_max]
        """
        box_coordinates = self.get()
        y_min = box_coordinates[:, 0]
        x_min = box_coordinates[:, 1]
        y_max = box_coordinates[:, 2]
        x_max = box_coordinates[:, 3]
        return [y_min, x_min, y_max, x_max]

    def _is_valid_boxes(self, data):
        """Check whether data fulfills the format of N*[ymin, xmin, ymax, xmax].
        Args:
          data: a numpy array of shape [N, 4] representing box coordinates
        Returns:
          a boolean indicating whether every box has ymax greater than or equal
              to ymin and xmax greater than or equal to xmin.
        """
        if data.shape[0] > 0:
            for i in range(data.shape[0]):
                if data[i, 0] > data[i, 2] or data[i, 1] > data[i, 3]:
                    return False
        return True
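

# Illustrative sketch, not part of the original module: the per-row loop in
# BoxList._is_valid_boxes could probably be written as a single vectorized
# comparison. The helper name is hypothetical, and numpy is re-imported only so
# that the sketch runs on its own.
import numpy as np


def _example_is_valid_boxes_vectorized(data):
    """Returns True if no row has y_min > y_max or x_min > x_max."""
    data = np.asarray(data)
    # np.any over an empty [0, 4] array is False, so empty input stays valid,
    # which matches the behavior of the loop-based check.
    return not (np.any(data[:, 0] > data[:, 2]) or np.any(data[:, 1] > data[:, 3]))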


def area(boxes):
    """Computes area of boxes.

    Args:
      boxes: Numpy array with shape [N, 4] holding N boxes

    Returns:
      a numpy array with shape [N] representing box areas
    """
    return (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])


def intersection(boxes1, boxes2):
    """Compute pairwise intersection areas between boxes.

    Args:
      boxes1: a numpy array with shape [N, 4] holding N boxes
      boxes2: a numpy array with shape [M, 4] holding M boxes

    Returns:
      a numpy array with shape [N, M] representing pairwise intersection area
    """
    [y_min1, x_min1, y_max1, x_max1] = np.split(boxes1, 4, axis=1)
    [y_min2, x_min2, y_max2, x_max2] = np.split(boxes2, 4, axis=1)

    all_pairs_min_ymax = np.minimum(y_max1, np.transpose(y_max2))
    all_pairs_max_ymin = np.maximum(y_min1, np.transpose(y_min2))
    intersect_heights = np.maximum(np.zeros(all_pairs_max_ymin.shape), all_pairs_min_ymax - all_pairs_max_ymin)
    all_pairs_min_xmax = np.minimum(x_max1, np.transpose(x_max2))
    all_pairs_max_xmin = np.maximum(x_min1, np.transpose(x_min2))
    intersect_widths = np.maximum(np.zeros(all_pairs_max_xmin.shape), all_pairs_min_xmax - all_pairs_max_xmin)
    return intersect_heights * intersect_widths


def iou(boxes1, boxes2):
    """Computes pairwise intersection-over-union between box collections.

    Args:
      boxes1: a numpy array with shape [N, 4] holding N boxes.
      boxes2: a numpy array with shape [M, 4] holding M boxes.

    Returns:
      a numpy array with shape [N, M] representing pairwise iou scores.
    """
    intersect = intersection(boxes1, boxes2)
    area1 = area(boxes1)
    area2 = area(boxes2)
    union = np.expand_dims(area1, axis=1) + np.expand_dims(area2, axis=0) - intersect
    return intersect / union
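

# Illustrative sketch, not part of the original module: a hand-checkable
# instance of the pairwise IOU arithmetic used by iou() above, written directly
# with broadcasting. The helper name and the box values are hypothetical. It
# assumes boxes with positive area; a zero union would divide by zero.
import numpy as np


def _example_pairwise_iou():
    """Returns the [1, 2] IOU matrix for three hand-checkable boxes."""
    # Boxes are [y_min, x_min, y_max, x_max]; each one below has area 4.
    boxes1 = np.array([[0.0, 0.0, 2.0, 2.0]])
    boxes2 = np.array([[1.0, 1.0, 3.0, 3.0], [0.0, 0.0, 2.0, 2.0]])
    # Indexing with None makes [N, 1] and [1, M] views that broadcast to [N, M].
    heights = np.maximum(
        0.0, np.minimum(boxes1[:, None, 2], boxes2[None, :, 2])
        - np.maximum(boxes1[:, None, 0], boxes2[None, :, 0]))
    widths = np.maximum(
        0.0, np.minimum(boxes1[:, None, 3], boxes2[None, :, 3])
        - np.maximum(boxes1[:, None, 1], boxes2[None, :, 1]))
    intersect = heights * widths  # [[1., 4.]]
    area1 = (boxes1[:, 2] - boxes1[:, 0]) * (boxes1[:, 3] - boxes1[:, 1])
    area2 = (boxes2[:, 2] - boxes2[:, 0]) * (boxes2[:, 3] - boxes2[:, 1])
    union = area1[:, None] + area2[None, :] - intersect  # [[7., 4.]]
    return intersect / union  # [[1/7, 1.]]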


def ioa(boxes1, boxes2):
    """Computes pairwise intersection-over-area between box collections.

    Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
    their intersection area over box2's area. Note that ioa is not symmetric,
    that is, IOA(box1, box2) != IOA(box2, box1).

    Args:
      boxes1: a numpy array with shape [N, 4] holding N boxes.
      boxes2: a numpy array with shape [M, 4] holding M boxes.

    Returns:
      a numpy array with shape [N, M] representing pairwise ioa scores.
    """
    intersect = intersection(boxes1, boxes2)
    areas = np.expand_dims(area(boxes2), axis=0)
    return intersect / areas
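

# Illustrative sketch, not part of the original module: a single nested pair of
# boxes showing why IOA is not symmetric. The helper name and the box values
# are hypothetical.
import numpy as np


def _example_ioa_asymmetry():
    """Returns (IOA(big, small), IOA(small, big)) for one nested pair of boxes."""
    big = np.array([0.0, 0.0, 2.0, 2.0])    # area 4
    small = np.array([0.0, 0.0, 1.0, 1.0])  # area 1, fully inside big

    def _area(box):
        return (box[2] - box[0]) * (box[3] - box[1])

    height = max(0.0, min(big[2], small[2]) - max(big[0], small[0]))
    width = max(0.0, min(big[3], small[3]) - max(big[1], small[1]))
    intersect = height * width  # 1.0, the whole of the small box
    # IOA divides by the area of the second argument, so the order matters.
    return intersect / _area(small), intersect / _area(big)  # (1.0, 0.25)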


class SortOrder(object):
    """Enum class for sort order.

    Attributes:
      ASCEND: ascending order.
      DESCEND: descending order.
    """
    ASCEND = 1
    DESCEND = 2


def area_boxlist(boxlist):
    """Computes area of boxes.

    Args:
      boxlist: BoxList holding N boxes

    Returns:
      a numpy array with shape [N] representing box areas
    """
    y_min, x_min, y_max, x_max = boxlist.get_coordinates()
    return (y_max - y_min) * (x_max - x_min)


def intersection_boxlist(boxlist1, boxlist2):
    """Compute pairwise intersection areas between boxes.

    Args:
      boxlist1: BoxList holding N boxes
      boxlist2: BoxList holding M boxes

    Returns:
      a numpy array with shape [N, M] representing pairwise intersection area
    """
    return intersection(boxlist1.get(), boxlist2.get())


def iou_boxlist(boxlist1, boxlist2):
    """Computes pairwise intersection-over-union between box collections.

    Args:
      boxlist1: BoxList holding N boxes
      boxlist2: BoxList holding M boxes

    Returns:
      a numpy array with shape [N, M] representing pairwise iou scores.
    """
    return iou(boxlist1.get(), boxlist2.get())


def ioa_boxlist(boxlist1, boxlist2):
    """Computes pairwise intersection-over-area between box collections.

    Intersection-over-area (ioa) between two boxes box1 and box2 is defined as
    their intersection area over box2's area. Note that ioa is not symmetric,
    that is, IOA(box1, box2) != IOA(box2, box1).

    Args:
      boxlist1: BoxList holding N boxes
      boxlist2: BoxList holding M boxes

    Returns:
      a numpy array with shape [N, M] representing pairwise ioa scores.
    """
    return ioa(boxlist1.get(), boxlist2.get())


def gather_boxlist(boxlist, indices, fields=None):
    """Gather boxes from BoxList according to indices and return new BoxList.

    By default, gather returns boxes corresponding to the input index list, as
    well as all additional fields stored in the boxlist (indexing into the
    first dimension).  However one can optionally only gather from a
    subset of fields.

    Args:
      boxlist: BoxList holding N boxes
      indices: a 1-d numpy array of type int_
      fields: (optional) list of fields to also gather from.  If None (default),
          all fields are gathered from.  Pass an empty fields list to only gather the box coordinates.

    Returns:
      subboxlist: a BoxList corresponding to the subset of the input BoxList specified by indices

    Raises:
      ValueError: if indices are out of the range [0, N) or if a specified field is not contained in boxlist
    """
    if indices.size:
        if np.amax(indices) >= boxlist.num_boxes() or np.amin(indices) < 0:
            raise ValueError('indices are out of valid range.')
    subboxlist = BoxList(boxlist.get()[indices, :])
    if fields is None:
        fields = boxlist.get_extra_fields()
    for field in fields:
        extra_field_data = boxlist.get_field(field)
        subboxlist.add_field(field, extra_field_data[indices, ...])
    return subboxlist


def sort_by_field_boxlist(boxlist, field, order=SortOrder.DESCEND):
    """Sort boxes and associated fields according to a scalar field.

    A common use case is reordering the boxes according to descending scores.

    Args:
        boxlist: BoxList holding N boxes.
        field: A BoxList field for sorting and reordering the BoxList.
        order: (Optional) SortOrder.DESCEND or SortOrder.ASCEND. Default is DESCEND.

    Returns:
      sorted_boxlist: A sorted BoxList with the field in the specified order.

    Raises:
        ValueError: if specified field does not exist or is not of single dimension.
        ValueError: if the order is not either descend or ascend.
    """
    if not boxlist.has_field(field):
        raise ValueError('Field ' + field + ' does not exist')
    if len(boxlist.get_field(field).shape) != 1:
        raise ValueError('Field ' + field + ' should be single dimension.')
    if order != SortOrder.DESCEND and order != SortOrder.ASCEND:
        raise ValueError('Invalid sort order')

    field_to_sort = boxlist.get_field(field)
    sorted_indices = np.argsort(field_to_sort)
    if order == SortOrder.DESCEND:
        sorted_indices = sorted_indices[::-1]
    return gather_boxlist(boxlist, sorted_indices)


def non_max_suppression(boxlist, max_output_size=10000, iou_threshold=1.0, score_threshold=-10.0):
    """Non maximum suppression.

    This op greedily selects a subset of detection bounding boxes, pruning
    away boxes that have high IOU (intersection over union) overlap (> thresh)
    with already selected boxes. In each iteration, the detected bounding box with
    highest score in the available pool is selected.

    Args:
        boxlist: BoxList holding N boxes.  Must contain a 'scores' field
            representing detection scores. All scores belong to the same class.
        max_output_size: maximum number of retained boxes
        iou_threshold: intersection over union threshold.
        score_threshold: minimum score threshold. Boxes whose scores are not greater
            than this value are removed. The default of -10 is low enough to keep
            virtually all boxes unless the caller sets a different threshold.

    Returns:
        a BoxList holding M boxes where M <= max_output_size
    Raises:
        ValueError: if 'scores' field does not exist
        ValueError: if iou_threshold is not in [0, 1]
        ValueError: if max_output_size < 0
    """
    if not boxlist.has_field('scores'):
        raise ValueError('Field scores does not exist')
    if iou_threshold < 0. or iou_threshold > 1.0:
        raise ValueError('IOU threshold must be in [0, 1]')
    if max_output_size < 0:
        raise ValueError('max_output_size must be non-negative.')

    boxlist = filter_scores_greater_than(boxlist, score_threshold)
    if boxlist.num_boxes() == 0:
        return boxlist

    boxlist = sort_by_field_boxlist(boxlist, 'scores')

    # Prevent further computation if NMS is disabled.
    if iou_threshold == 1.0:
        if boxlist.num_boxes() > max_output_size:
            selected_indices = np.arange(max_output_size)
            return gather_boxlist(boxlist, selected_indices)
        else:
            return boxlist

    boxes = boxlist.get()
    num_boxes = boxlist.num_boxes()
    # is_index_valid is True only for boxes that have not been selected or suppressed.
    is_index_valid = np.ones(num_boxes, dtype=bool)
    selected_indices = []
    num_output = 0
    for i in range(num_boxes):
        if num_output < max_output_size:
            if is_index_valid[i]:
                num_output += 1
                selected_indices.append(i)
                is_index_valid[i] = False
                valid_indices = np.where(is_index_valid)[0]
                if valid_indices.size == 0:
                    break

                intersect_over_union = iou(np.expand_dims(boxes[i, :], axis=0), boxes[valid_indices, :])
                intersect_over_union = np.squeeze(intersect_over_union, axis=0)
                is_index_valid[valid_indices] = np.logical_and(
                    is_index_valid[valid_indices],
                    intersect_over_union <= iou_threshold)
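    # An explicit integer dtype keeps the indexing valid when nothing is selected
    # (for example when max_output_size == 0), since np.array([]) is float64.
    return gather_boxlist(boxlist, np.array(selected_indices, dtype=np.int64))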
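

# Illustrative sketch, not part of the original module: the greedy selection
# described in non_max_suppression above, reduced to plain arrays so the
# control flow is easy to follow. The helper name, its signature, and the
# default threshold are hypothetical. It assumes every box has positive area.
import numpy as np


def _example_greedy_nms(boxes, scores, iou_threshold=0.5):
    """Returns indices of the kept boxes, highest score first."""
    boxes = np.asarray(boxes, dtype=np.float64)
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    order = np.argsort(scores)[::-1]  # candidate indices by descending score
    keep = []
    while order.size > 0:
        best, rest = order[0], order[1:]
        keep.append(int(best))
        # IOU between the selected box and every remaining candidate.
        heights = np.maximum(0.0, np.minimum(boxes[best, 2], boxes[rest, 2])
                             - np.maximum(boxes[best, 0], boxes[rest, 0]))
        widths = np.maximum(0.0, np.minimum(boxes[best, 3], boxes[rest, 3])
                            - np.maximum(boxes[best, 1], boxes[rest, 1]))
        intersect = heights * widths
        iou_with_best = intersect / (areas[best] + areas[rest] - intersect)
        # Drop candidates that overlap the selected box by more than the threshold.
        order = rest[iou_with_best <= iou_threshold]
    return keep


# Usage: boxes 0 and 1 overlap with IOU 81/119 (about 0.68), so box 1 is
# suppressed, while box 2 is disjoint from box 0 and survives:
#   _example_greedy_nms([[0, 0, 10, 10], [1, 1, 11, 11], [20, 20, 30, 30]],
#                       [0.9, 0.8, 0.7])  # returns [0, 2]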


def multi_class_non_max_suppression(boxlist, score_thresh, iou_thresh, max_output_size):
    """Multi-class version of non maximum suppression.

    This op greedily selects a subset of detection bounding boxes, pruning
    away boxes that have high IOU (intersection over union) overlap (> thresh)
    with already selected boxes.  It operates independently for each class for
    which scores are provided (via the scores field of the input box_list),
    pruning boxes with score less than a provided threshold prior to
    applying NMS.

    Args:
        boxlist: BoxList holding N boxes.  Must contain a 'scores' field
            representing detection scores.  This scores field is a numpy array
            that can be 1-dimensional (in the case of a single class) or
            2-dimensional, in which case we assume that it takes the shape
            [num_boxes, num_classes]. The number of classes is read from
            scores.shape[1].
        score_thresh: scalar threshold for score (low scoring boxes are removed).
        iou_thresh: scalar threshold for IOU (boxes that have high IOU overlap
            with previously selected boxes are removed).
        max_output_size: maximum number of retained boxes per class.

    Returns:
        a BoxList holding M boxes with a rank-1 scores field representing
            corresponding scores for each box with scores sorted in decreasing order
            and a rank-1 classes field representing a class label for each box.
    Raises:
        ValueError: if iou_thresh is not in [0, 1] or if input boxlist does not have
            a valid scores field.
    """
    if not 0 <= iou_thresh <= 1.0:
        raise ValueError('thresh must be between 0 and 1')
    if not isinstance(boxlist, BoxList):
        raise ValueError('boxlist must be a BoxList')
    if not boxlist.has_field('scores'):
        raise ValueError('input boxlist must have \'scores\' field')
    scores = boxlist.get_field('scores')
    if len(scores.shape) == 1:
        scores = np.reshape(scores, [-1, 1])
    elif len(scores.shape) == 2:
        if scores.shape[1] is None:
            raise ValueError('scores field must have statically defined second dimension')
    else:
        raise ValueError('scores field must be of rank 1 or 2')
    num_boxes = boxlist.num_boxes()
    num_scores = scores.shape[0]
    num_classes = scores.shape[1]

    if num_boxes != num_scores:
        raise ValueError('Incorrect scores field length: actual vs expected.')

    selected_boxes_list = []
    for class_idx in range(num_classes):
        boxlist_and_class_scores = BoxList(boxlist.get())
        class_scores = np.reshape(scores[0:num_scores, class_idx], [-1])
        boxlist_and_class_scores.add_field('scores', class_scores)
        boxlist_filt = filter_scores_greater_than(boxlist_and_class_scores, score_thresh)
        nms_result = non_max_suppression(
            boxlist_filt, max_output_size=max_output_size, iou_threshold=iou_thresh, score_threshold=score_thresh)
        nms_result.add_field('classes', np.zeros_like(nms_result.get_field('scores')) + class_idx)
        selected_boxes_list.append(nms_result)
    selected_boxes = concatenate_boxlist(selected_boxes_list)
    sorted_boxes = sort_by_field_boxlist(selected_boxes, 'scores')
    return sorted_boxes


def scale(boxlist, y_scale, x_scale):
    """Scale box coordinates in x and y dimensions.

    Args:
        boxlist: BoxList holding N boxes
        y_scale: float
        x_scale: float

    Returns:
        boxlist: BoxList holding N boxes
    """
    y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
    y_min = y_scale * y_min
    y_max = y_scale * y_max
    x_min = x_scale * x_min
    x_max = x_scale * x_max
    scaled_boxlist = BoxList(np.hstack([y_min, x_min, y_max, x_max]))

    fields = boxlist.get_extra_fields()
    for field in fields:
        extra_field_data = boxlist.get_field(field)
        scaled_boxlist.add_field(field, extra_field_data)

    return scaled_boxlist


def clip_to_window(boxlist, window, filter_nonoverlapping=True):
    """Clip bounding boxes to a window.

    This op clips input bounding boxes (represented by bounding box
    corners) to a window, optionally filtering out boxes that do not
    overlap at all with the window.

    Args:
        boxlist: BoxList holding M_in boxes
        window: a numpy array of shape [4] representing the [y_min, x_min, y_max, x_max]
            window to which the op should clip boxes.
        filter_nonoverlapping: whether to filter out boxes that do not overlap at all with the window.

    Returns:
        a BoxList holding M_out boxes where M_out <= M_in
    """
    y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
    win_y_min = window[0]
    win_x_min = window[1]
    win_y_max = window[2]
    win_x_max = window[3]
    y_min_clipped = np.fmax(np.fmin(y_min, win_y_max), win_y_min)
    y_max_clipped = np.fmax(np.fmin(y_max, win_y_max), win_y_min)
    x_min_clipped = np.fmax(np.fmin(x_min, win_x_max), win_x_min)
    x_max_clipped = np.fmax(np.fmin(x_max, win_x_max), win_x_min)
    clipped = BoxList(np.hstack([y_min_clipped, x_min_clipped, y_max_clipped, x_max_clipped]))
    clipped = _copy_extra_fields(clipped, boxlist)
    if filter_nonoverlapping:
        areas = area_boxlist(clipped)
        nonzero_area_indices = np.reshape(np.nonzero(np.greater(areas, 0.0)), [-1]).astype(np.int32)
        clipped = gather_boxlist(clipped, nonzero_area_indices)
    return clipped
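

# Illustrative sketch, not part of the original module: the clipping arithmetic
# of clip_to_window above on a plain [N, 4] array. It uses np.clip in place of
# the fmax/fmin pairs, which should agree as long as the inputs contain no NaN
# values. The helper name and signature are hypothetical.
import numpy as np


def _example_clip_boxes(boxes, window):
    """Clips boxes to window and drops boxes left with zero area."""
    boxes = np.asarray(boxes, dtype=np.float64)
    win_y_min, win_x_min, win_y_max, win_x_max = window
    clipped = np.stack([
        np.clip(boxes[:, 0], win_y_min, win_y_max),
        np.clip(boxes[:, 1], win_x_min, win_x_max),
        np.clip(boxes[:, 2], win_y_min, win_y_max),
        np.clip(boxes[:, 3], win_x_min, win_x_max)], axis=1)
    areas = (clipped[:, 2] - clipped[:, 0]) * (clipped[:, 3] - clipped[:, 1])
    # A box that lies entirely outside the window collapses to zero area.
    return clipped[areas > 0.0]


# Usage: the first box is trimmed to the unit window, the second lies outside:
#   _example_clip_boxes([[-1, -1, 0.5, 0.5], [2, 2, 3, 3]], [0, 0, 1, 1])
#   returns [[0., 0., 0.5, 0.5]]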


def prune_non_overlapping_boxes(boxlist1, boxlist2, minoverlap=0.0):
    """Prunes the boxes in boxlist1 that overlap less than minoverlap with boxlist2.

    For each box in boxlist1, we want its IOA to be at least minoverlap with
    at least one of the boxes in boxlist2. If it is not, we remove it.

    Args:
        boxlist1: BoxList holding N boxes.
        boxlist2: BoxList holding M boxes.
        minoverlap: Minimum required overlap between boxes, to count them as overlapping.

    Returns:
        A BoxList holding the N' boxes of boxlist1 that pass the overlap test.
    """
    intersection_over_area = ioa_boxlist(boxlist2, boxlist1)  # [M, N] array
    intersection_over_area = np.amax(intersection_over_area, axis=0)  # [N] array
    keep_bool = np.greater_equal(intersection_over_area, np.array(minoverlap))
    keep_inds = np.nonzero(keep_bool)[0]
    new_boxlist1 = gather_boxlist(boxlist1, keep_inds)
    return new_boxlist1


def prune_outside_window(boxlist, window):
    """Prunes bounding boxes that fall outside a given window.

    This function prunes bounding boxes that even partially fall outside the given
    window. See also clip_to_window, which only prunes bounding boxes that fall
    completely outside the window, and clips any bounding boxes that partially
    overflow.

    Args:
        boxlist: a BoxList holding M_in boxes.
        window: a numpy array of size 4, representing [ymin, xmin, ymax, xmax] of the window.

    Returns:
        pruned_boxlist: a BoxList holding M_out boxes where M_out <= M_in.
        valid_indices: a numpy array with shape [M_out] indexing the valid bounding boxes in the input BoxList.
    """

    y_min, x_min, y_max, x_max = np.array_split(boxlist.get(), 4, axis=1)
    win_y_min = window[0]
    win_x_min = window[1]
    win_y_max = window[2]
    win_x_max = window[3]
    coordinate_violations = np.hstack([
        np.less(y_min, win_y_min), np.less(x_min, win_x_min),
        np.greater(y_max, win_y_max), np.greater(x_max, win_x_max)])
    valid_indices = np.reshape(np.where(np.logical_not(np.max(coordinate_violations, axis=1))), [-1])
    return gather_boxlist(boxlist, valid_indices), valid_indices


def concatenate_boxlist(boxlists, fields=None):
    """Concatenate list of BoxLists.

    This op concatenates a list of input BoxLists into a larger BoxList.  It also
    handles concatenation of BoxList fields as long as the field tensor shapes
    are equal except for the first dimension.

    Args:
      boxlists: list of BoxList objects
      fields: optional list of fields to also concatenate.  By default, all
        fields from the first BoxList in the list are included in the concatenation.

    Returns:
      a BoxList with number of boxes equal to
        sum([boxlist.num_boxes() for boxlist in boxlists])
    Raises:
      ValueError: if boxlists is invalid (i.e., is not a list, is empty, or
        contains non BoxList objects), or if requested fields are not contained in all boxlists
    """
    if not isinstance(boxlists, list):
        raise ValueError('boxlists should be a list')
    if not boxlists:
        raise ValueError('boxlists should have nonzero length')
    for boxlist in boxlists:
        if not isinstance(boxlist, BoxList):
            raise ValueError('all elements of boxlists should be BoxList objects')
    concatenated = BoxList(np.vstack([boxlist.get() for boxlist in boxlists]))
    if fields is None:
        fields = boxlists[0].get_extra_fields()
    for field in fields:
        first_field_shape = boxlists[0].get_field(field).shape
        first_field_shape = first_field_shape[1:]
        for boxlist in boxlists:
            if not boxlist.has_field(field):
                raise ValueError('boxlist must contain all requested fields')
            field_shape = boxlist.get_field(field).shape
            field_shape = field_shape[1:]
            if field_shape != first_field_shape:
                raise ValueError('field %s must have same shape for all boxlists '
                                 'except for the 0th dimension.' % field)
        concatenated_field = np.concatenate([boxlist.get_field(field) for boxlist in boxlists], axis=0)
        concatenated.add_field(field, concatenated_field)
    return concatenated


def filter_scores_greater_than(boxlist, thresh):
    """Filter to keep only boxes with score exceeding a given threshold.

    This op keeps the collection of boxes whose corresponding scores are
    greater than the input threshold.

    Args:
      boxlist: BoxList holding N boxes.  Must contain a 'scores' field representing detection scores.
      thresh: scalar threshold

    Returns:
      a BoxList holding M boxes where M <= N

    Raises:
      ValueError: if boxlist not a BoxList object or if it does not have a scores field
    """
    if not isinstance(boxlist, BoxList):
        raise ValueError('boxlist must be a BoxList')
    if not boxlist.has_field('scores'):
        raise ValueError('input boxlist must have \'scores\' field')
    scores = boxlist.get_field('scores')
    if len(scores.shape) > 2:
        raise ValueError('Scores should have rank 1 or 2')
    if len(scores.shape) == 2 and scores.shape[1] != 1:
        raise ValueError('Scores should have rank 1 or have shape '
                         'consistent with [None, 1]')
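    # Flatten the scores first so that a [N, 1] field yields one index per box;
    # np.where on a 2-d array would return row and column indices together.
    high_score_indices = np.where(np.greater(np.reshape(scores, [-1]), thresh))[0].astype(np.int32)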
    return gather_boxlist(boxlist, high_score_indices)


def change_coordinate_frame(boxlist, window):
    """Change coordinate frame of the boxlist to be relative to window's frame.

    Given a window of the form [ymin, xmin, ymax, xmax],
    changes bounding box coordinates from boxlist to be relative to this window
    (e.g., the min corner maps to (0,0) and the max corner maps to (1,1)).

    An example use case is data augmentation: where we are given groundtruth
    boxes (boxlist) and would like to randomly crop the image to some
    window (window). In this case we need to change the coordinate frame of
    each groundtruth box to be relative to this new window.

    Args:
      boxlist: A BoxList object holding N boxes.
      window: a size 4 1-D numpy array.

    Returns:
      Returns a BoxList object with N boxes.
    """
    win_height = window[2] - window[0]
    win_width = window[3] - window[1]
    boxlist_new = scale(
        BoxList(boxlist.get() - [window[0], window[1], window[0], window[1]]), 1.0 / win_height, 1.0 / win_width)
    _copy_extra_fields(boxlist_new, boxlist)

    return boxlist_new
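

# Illustrative sketch, not part of the original module: the shift-then-scale
# arithmetic behind change_coordinate_frame above, on a plain [N, 4] array.
# The helper name and signature are hypothetical. It assumes a window with
# positive height and width.
import numpy as np


def _example_change_coordinate_frame(boxes, window):
    """Expresses boxes relative to window, so the window maps to [0, 0, 1, 1]."""
    boxes = np.asarray(boxes, dtype=np.float64)
    window = np.asarray(window, dtype=np.float64)
    # Shift so that the window's min corner becomes the origin.
    offset = np.array([window[0], window[1], window[0], window[1]])
    win_height = window[2] - window[0]
    win_width = window[3] - window[1]
    # Scale y coordinates by 1 / height and x coordinates by 1 / width.
    factors = np.array([1.0 / win_height, 1.0 / win_width,
                        1.0 / win_height, 1.0 / win_width])
    return (boxes - offset) * factors


# Usage: a box covering the lower-right quarter of the window [0, 0, 0.5, 0.5]:
#   _example_change_coordinate_frame([[0.25, 0.25, 0.5, 0.5]], [0, 0, 0.5, 0.5])
#   returns [[0.5, 0.5, 1., 1.]]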


def _copy_extra_fields(boxlist_to_copy_to, boxlist_to_copy_from):
    """Copies the extra fields of boxlist_to_copy_from to boxlist_to_copy_to.

    Args:
      boxlist_to_copy_to: BoxList to which extra fields are copied.
      boxlist_to_copy_from: BoxList from which fields are copied.

    Returns:
      boxlist_to_copy_to with extra fields.
    """
    for field in boxlist_to_copy_from.get_extra_fields():
        boxlist_to_copy_to.add_field(field, boxlist_to_copy_from.get_field(field))
    return boxlist_to_copy_to


def _update_valid_indices_by_removing_high_iou_boxes(
        selected_indices, is_index_valid, intersect_over_union, threshold):
    """Marks boxes invalid when their max IOU with any selected box exceeds threshold."""
    max_iou = np.max(intersect_over_union[:, selected_indices], axis=1)
    return np.logical_and(is_index_valid, max_iou <= threshold)