File size: 3,610 Bytes
fa84113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# Copyright 2020 Google Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Region Similarity Calculators for BoxLists.

Region Similarity Calculators compare a pairwise measure of similarity
between the boxes in two BoxLists.
"""
import torch
from .box_list import BoxList


def area(boxlist: BoxList):
    """Compute the area of every box in a BoxList.

    Box coordinates are read column-wise in the order
    [y_min, x_min, y_max, x_max]. No clamping is applied, so a box whose
    max coordinate is below its min coordinate yields a non-positive area.

    Args:
        boxlist: BoxList holding N boxes

    Returns:
        a tensor with shape [N] holding (y_max - y_min) * (x_max - x_min)
        for each box.
    """
    boxes = boxlist.boxes()
    heights = boxes[:, 2] - boxes[:, 0]
    widths = boxes[:, 3] - boxes[:, 1]
    return heights * widths


def _pairwise_overlap(lo1: torch.Tensor, hi1: torch.Tensor,
                      lo2: torch.Tensor, hi2: torch.Tensor) -> torch.Tensor:
    """Overlap length along one axis for every pairing of two interval sets.

    lo1/hi1 are 1-D tensors of length N and lo2/hi2 are 1-D tensors of
    length M; the result is [N, M] and is clamped so disjoint intervals give 0.
    """
    # Broadcast [N, 1] against [1, M] to cover every pair at once.
    upper = torch.min(hi1.unsqueeze(1), hi2.unsqueeze(0))
    lower = torch.max(lo1.unsqueeze(1), lo2.unsqueeze(0))
    return (upper - lower).clamp(min=0)


def intersection(boxlist1: BoxList, boxlist2: BoxList):
    """Compute pairwise intersection areas between two sets of boxes.

    Box coordinates are read column-wise as [y_min, x_min, y_max, x_max].

    Args:
        boxlist1: BoxList holding N boxes
        boxlist2: BoxList holding M boxes

    Returns:
        a tensor with shape [N, M] whose (i, j) entry is the overlap area of
        box i of boxlist1 with box j of boxlist2 (0 when they do not overlap).
    """
    boxes1 = boxlist1.boxes()
    boxes2 = boxlist2.boxes()
    heights = _pairwise_overlap(boxes1[:, 0], boxes1[:, 2], boxes2[:, 0], boxes2[:, 2])
    widths = _pairwise_overlap(boxes1[:, 1], boxes1[:, 3], boxes2[:, 1], boxes2[:, 3])
    return heights * widths


def iou(boxlist1: BoxList, boxlist2: BoxList):
    """Computes pairwise intersection-over-union between box collections.

    Pairs whose intersection is exactly 0 get an IOU of 0. The division is
    performed against a "safe" union (1 wherever the intersection is 0), so
    degenerate pairs with a zero union never produce 0/0 = NaN. This matters
    for autograd: a NaN in the discarded branch of ``torch.where`` still
    propagates NaN into the gradient.

    Args:
        boxlist1: BoxList holding N boxes
        boxlist2: BoxList holding M boxes

    Returns:
        a tensor with shape [N, M] representing pairwise iou scores.
    """
    intersections = intersection(boxlist1, boxlist2)
    areas1 = area(boxlist1)
    areas2 = area(boxlist2)
    unions = areas1.unsqueeze(1) + areas2.unsqueeze(0) - intersections
    no_overlap = intersections == 0.0
    # Replace the denominator where the result is masked anyway, so neither the
    # forward pass nor the backward pass ever evaluates 0/0.
    safe_unions = torch.where(no_overlap, torch.ones_like(unions), unions)
    return torch.where(no_overlap, torch.zeros_like(intersections), intersections / safe_unions)


@torch.jit.script
class IouSimilarity(object):
    """Similarity calculator based on the Intersection over Union (IOU) metric.

    The class is compiled with ``torch.jit.script``. It holds no state; its only
    method delegates to the module-level ``iou`` function.
    """
    def __init__(self):
        # Stateless: there is nothing to initialise.
        pass

    def compare(self, boxlist1: BoxList, boxlist2: BoxList):
        """Computes the matrix of pairwise IOU similarity between two BoxLists.

        Higher values indicate more similarity. This method only measures
        similarity; it does not perform any matching between the boxes.

        Args:
            boxlist1: BoxList holding N boxes.
            boxlist2: BoxList holding M boxes.

        Returns:
            a tensor of shape [N, M] with pairwise IOU scores (0 for pairs that
            do not overlap). Its dtype is that of dividing the box-area tensors,
            e.g. float32 for float32 boxes and float64 for float64 boxes.
        """
        return iou(boxlist1, boxlist2)