Spaces:
Sleeping
Sleeping
# Copyright 2020 Google Research. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
# ============================================================================== | |
"""Matcher interface and Match class. | |
This module defines the Matcher interface and the Match object. The job of the | |
matcher is to match row and column indices based on the similarity matrix and | |
other optional parameters. Each column is matched to at most one row. There | |
are three possibilities for the matching: | |
1) match: A column matches a row. | |
2) no_match: A column does not match any row. | |
3) ignore: A column that is neither 'match' nor 'no_match'.
The ignore case is regularly encountered in object detection: when an anchor has | |
a relatively small overlap with a ground-truth box, one neither wants to | |
consider this box a positive example (match) nor a negative example (no match). | |
The Match class is used to store the match results and it provides simple APIs
to query the results.
""" | |
import torch | |
class Match(object):
    """Class to store results from the matcher.

    This class is used to store the results from the matcher. It provides
    convenient methods to query the matching results.

    Encoding of ``match_results`` (one entry per column i):
      * ``match_results[i] >= 0``: column i is matched to row ``match_results[i]``.
      * ``match_results[i] == -1``: column i is unmatched.
      * ``match_results[i] == -2``: column i is ignored.
    """

    def __init__(self, match_results: torch.Tensor):
        """Constructs a Match object.

        Args:
            match_results: Integer tensor of shape [N] with (1) match_results[i]>=0,
                meaning that column i is matched with row match_results[i].
                (2) match_results[i]=-1, meaning that column i is not matched.
                (3) match_results[i]=-2, meaning that column i is ignored.

        Raises:
            ValueError: if match_results does not have rank 1 or its dtype is
                neither int32 nor int64.
        """
        if len(match_results.shape) != 1:
            raise ValueError('match_results should have rank 1')
        if match_results.dtype not in (torch.int32, torch.int64):
            raise ValueError('match_results should be an int32 or int64 tensor')
        self.match_results = match_results

    def matched_column_indices(self) -> torch.Tensor:
        """Returns column indices that match to some row.

        The indices returned by this op are always sorted in increasing order.

        Returns:
            column_indices: int64 tensor of shape [K] with column indices.
        """
        return torch.nonzero(self.matched_column_indicator()).flatten().long()

    def matched_column_indicator(self) -> torch.Tensor:
        """Returns a boolean mask of shape [N] that is True for matched columns."""
        return self.match_results >= 0

    def num_matched_columns(self) -> int:
        """Returns the number (Python int) of matched columns."""
        return self.matched_column_indices().numel()

    def unmatched_column_indices(self) -> torch.Tensor:
        """Returns column indices that do not match any row.

        The indices returned by this op are always sorted in increasing order.

        Returns:
            column_indices: int64 tensor of shape [K] with column indices.
        """
        return torch.nonzero(self.unmatched_column_indicator()).flatten().long()

    def unmatched_column_indicator(self) -> torch.Tensor:
        """Returns a boolean mask of shape [N] that is True for unmatched columns."""
        return self.match_results == -1

    def num_unmatched_columns(self) -> int:
        """Returns the number (Python int) of unmatched columns."""
        return self.unmatched_column_indices().numel()

    def ignored_column_indices(self) -> torch.Tensor:
        """Returns column indices that are ignored (neither Matched nor Unmatched).

        The indices returned by this op are always sorted in increasing order.

        Returns:
            column_indices: int64 tensor of shape [K] with column indices.
        """
        return torch.nonzero(self.ignored_column_indicator()).flatten().long()

    def ignored_column_indicator(self) -> torch.Tensor:
        """Returns boolean column indicator where True means the column is ignored.

        Returns:
            column_indicator: boolean vector which is True for all ignored column indices.
        """
        return self.match_results == -2

    def num_ignored_columns(self) -> int:
        """Returns the number (Python int) of ignored columns."""
        return self.ignored_column_indices().numel()

    def unmatched_or_ignored_column_indices(self) -> torch.Tensor:
        """Returns column indices that are unmatched or ignored.

        Any negative entry of match_results counts here.
        The indices returned by this op are always sorted in increasing order.

        Returns:
            column_indices: int64 tensor of shape [K] with column indices.
        """
        return torch.nonzero(self.match_results < 0).flatten().long()

    def matched_row_indices(self) -> torch.Tensor:
        """Returns row indices that match some column.

        The indices returned by this op are ordered so as to be in correspondence
        with the output of matched_column_indices(). For example if
        self.matched_column_indices() is [0, 2] and self.matched_row_indices()
        is [7, 3], then column 0 was matched to row 7 and column 2 was matched
        to row 3.

        Returns:
            row_indices: int64 tensor of shape [K] with row indices.
        """
        return torch.gather(self.match_results, 0, self.matched_column_indices()).flatten().long()

    def gather_based_on_match(self, input_tensor, unmatched_value, ignored_value):
        """Gathers elements from `input_tensor` based on match results.

        For columns that are matched to a row, gathered_tensor[col] is set to
        input_tensor[match_results[col]]. For columns that are unmatched,
        gathered_tensor[col] is set to unmatched_value. Finally, for columns that
        are ignored gathered_tensor[col] is set to ignored_value.

        Args:
            input_tensor: Tensor of shape [R, ...] to gather rows from.
            unmatched_value: Python scalar, or tensor of shape input_tensor.shape[1:]
                or [1] + input_tensor.shape[1:], used for unmatched columns.
                Scalars are converted to input_tensor's dtype.
            ignored_value: Same forms as unmatched_value, used for ignored columns.

        Returns:
            gathered_tensor: A tensor containing values gathered from input_tensor.
                The shape of the gathered tensor is
                [match_results.shape[0]] + input_tensor.shape[1:].
        """
        fill_shape = (1,) + tuple(input_tensor.shape[1:])

        def _as_fill_row(value):
            # Normalize a fill value to a single leading "row" that can be
            # prepended to input_tensor along dim 0.
            if isinstance(value, torch.Tensor):
                value = value.to(device=input_tensor.device)
                if value.dim() == input_tensor.dim() - 1:
                    # Value is given without the leading row dimension.
                    value = value.unsqueeze(0)
                return value
            # Python scalar: cast to input dtype and broadcast over trailing dims.
            return torch.tensor(
                value, dtype=input_tensor.dtype, device=input_tensor.device).expand(fill_shape)

        # Row 0 holds ignored_value and row 1 holds unmatched_value, so shifting
        # match_results by +2 maps -2 -> 0 (ignored), -1 -> 1 (unmatched) and
        # matched row k -> k + 2.
        lookup = torch.cat(
            [_as_fill_row(ignored_value), _as_fill_row(unmatched_value), input_tensor], dim=0)
        gather_indices = torch.clamp(self.match_results + 2, min=0)
        return torch.index_select(lookup, 0, gather_indices)