Spaces:
Sleeping
Sleeping
File size: 3,610 Bytes
fa84113 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
# Copyright 2020 Google Research. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Region Similarity Calculators for BoxLists.
Region Similarity Calculators compute a pairwise measure of similarity
between the boxes in two BoxLists.
"""
import torch
from .box_list import BoxList
def area(boxlist: BoxList):
    """Return the area of every box in a BoxList.

    Args:
        boxlist: BoxList holding N boxes.

    Returns:
        A tensor of shape [N] with one area per box.
    """
    coords = boxlist.boxes()
    # Coordinate columns are unpacked in (y_min, x_min, y_max, x_max) order.
    y_lo, x_lo, y_hi, x_hi = coords.chunk(4, dim=1)
    heights = (y_hi - y_lo).squeeze(1)
    widths = (x_hi - x_lo).squeeze(1)
    return heights * widths
def intersection(boxlist1: BoxList, boxlist2: BoxList):
    """Return the pairwise overlap area of two box collections.

    Args:
        boxlist1: BoxList holding N boxes.
        boxlist2: BoxList holding M boxes.

    Returns:
        A tensor of shape [N, M]; entry (i, j) is the area shared by box i of
        boxlist1 and box j of boxlist2, or 0 when they do not overlap.
    """
    ya_min, xa_min, ya_max, xa_max = boxlist1.boxes().chunk(4, dim=1)
    yb_min, xb_min, yb_max, xb_max = boxlist2.boxes().chunk(4, dim=1)
    # Assuming boxes() is [K, 4], each column slice is [K, 1]; transposing the
    # second collection's slices to [1, M] makes min/max broadcast to [N, M].
    overlap_h = (
        torch.min(ya_max, yb_max.t()) - torch.max(ya_min, yb_min.t())
    ).clamp(min=0)
    overlap_w = (
        torch.min(xa_max, xb_max.t()) - torch.max(xa_min, xb_min.t())
    ).clamp(min=0)
    return overlap_h * overlap_w
def iou(boxlist1: BoxList, boxlist2: BoxList):
    """Computes pairwise intersection-over-union between box collections.

    Args:
        boxlist1: BoxList holding N boxes.
        boxlist2: BoxList holding M boxes.

    Returns:
        A tensor with shape [N, M] of pairwise IoU scores. Pairs whose
        intersection is exactly zero score exactly 0.
    """
    intersections = intersection(boxlist1, boxlist2)
    areas1 = area(boxlist1)
    areas2 = area(boxlist2)
    # [N, 1] + [1, M] broadcasts to the [N, M] union matrix.
    unions = areas1.unsqueeze(1) + areas2.unsqueeze(0) - intersections
    no_overlap = intersections == 0.0
    # Where the result is forced to 0 anyway, divide by 1 instead of the real
    # union. torch.where still back-propagates through its unselected branch,
    # and a 0 / 0 there (e.g. two zero-area boxes) would otherwise turn the
    # gradient into NaN. Forward values are unchanged.
    safe_unions = torch.where(no_overlap, torch.ones_like(unions), unions)
    return torch.where(
        no_overlap, torch.zeros_like(intersections), intersections / safe_unions
    )
@torch.jit.script
class IouSimilarity(object):
    """Pairwise box similarity measured as Intersection over Union (IoU).

    Scores every box in one BoxList against every box in another using IoU.
    """

    def __init__(self):
        # No state is kept; the comparison needs no configuration.
        pass

    def compare(self, boxlist1: BoxList, boxlist2: BoxList):
        """Returns the matrix of pairwise similarities between two BoxLists.

        Higher values mean more similar boxes. This only scores pairs of
        boxes; it does not perform any matching between the two collections.

        Args:
            boxlist1: BoxList holding N boxes.
            boxlist2: BoxList holding M boxes.

        Returns:
            A (float32) tensor of shape [N, M] with pairwise similarity scores.
        """
        similarity = iou(boxlist1, boxlist2)
        return similarity
|