# Copyright 2020 Google Research. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """Bounding Box List definition. BoxList represents a list of bounding boxes as tensorflow tensors, where each bounding box is represented as a row of 4 numbers, [y_min, x_min, y_max, x_max]. It is assumed that all bounding boxes within a given list correspond to a single image. See also box_list.py for common box related operations (such as area, iou, etc). Optionally, users can add additional related fields (such as weights). We assume the following things to be true about fields: * they correspond to boxes in the box_list along the 0th dimension * they have inferable rank at graph construction time * all dimensions except for possibly the 0th can be inferred (i.e., not None) at graph construction time. Some other notes: * Following tensorflow conventions, we use height, width ordering, and correspondingly, y,x (or ymin, xmin, ymax, xmax) ordering * Tensors are always provided as (flat) [N, 4] tensors. """ import torch from typing import Optional, List, Dict @torch.jit.script class BoxList(object): """Box collection.""" data: Dict[str, torch.Tensor] def __init__(self, boxes): """Constructs box collection. Args: boxes: a tensor of shape [N, 4] representing box corners Raises: ValueError: if invalid dimensions for bbox data or if bbox data is not in float32 format. """ if len(boxes.shape) != 2 or boxes.shape[-1] != 4: raise ValueError('Invalid dimensions for box data.') if boxes.dtype != torch.float32: raise ValueError('Invalid tensor type: should be tf.float32') self.data = {'boxes': boxes} def num_boxes(self): """Returns number of boxes held in collection. Returns: a tensor representing the number of boxes held in the collection. """ return self.data['boxes'].shape[0] def get_all_fields(self): """Returns all fields.""" return self.data.keys() def get_extra_fields(self): """Returns all non-box fields (i.e., everything not named 'boxes').""" # return [k for k in self.data.keys() if k != 'boxes'] # FIXME torscript doesn't support comprehensions yet extra: List[str] = [] for k in self.data.keys(): if k != 'boxes': extra.append(k) return extra def add_field(self, field: str, field_data: torch.Tensor): """Add field to box list. This method can be used to add related box data such as weights/labels, etc. Args: field: a string key to access the data via `get` field_data: a tensor containing the data to store in the BoxList """ self.data[field] = field_data def has_field(self, field: str): return field in self.data #@property # FIXME for torchscript compat def boxes(self): """Convenience function for accessing box coordinates. Returns: a tensor with shape [N, 4] representing box coordinates. """ return self.get_field('boxes') #@boxes.setter # FIXME for torchscript compat def set_boxes(self, boxes): """Convenience function for setting box coordinates. Args: boxes: a tensor of shape [N, 4] representing box corners Raises: ValueError: if invalid dimensions for bbox data """ if len(boxes.shape) != 2 or boxes.shape[-1] != 4: raise ValueError('Invalid dimensions for box data.') self.data['boxes'] = boxes def get_field(self, field: str): """Accesses a box collection and associated fields. This function returns specified field with object; if no field is specified, it returns the box coordinates. Args: field: this optional string parameter can be used to specify a related field to be accessed. Returns: a tensor representing the box collection or an associated field. Raises: ValueError: if invalid field """ if not self.has_field(field): raise ValueError(f'field {field} does not exist') return self.data[field] def set_field(self, field: str, value: torch.Tensor): """Sets the value of a field. Updates the field of a box_list with a given value. Args: field: (string) name of the field to set value. value: the value to assign to the field. Raises: ValueError: if the box_list does not have specified field. """ if not self.has_field(field): raise ValueError(f'field {field} does not exist') self.data[field] = value def get_center_coordinates_and_sizes(self): """Computes the center coordinates, height and width of the boxes. Returns: a list of 4 1-D tensors [ycenter, xcenter, height, width]. """ box_corners = self.boxes() ymin, xmin, ymax, xmax = box_corners.t().unbind() width = xmax - xmin height = ymax - ymin ycenter = ymin + height / 2. xcenter = xmin + width / 2. return [ycenter, xcenter, height, width] def transpose_coordinates(self): """Transpose the coordinate representation in a boxlist. """ y_min, x_min, y_max, x_max = self.boxes().chunk(4, dim=1) self.set_boxes(torch.cat([x_min, y_min, x_max, y_max], 1)) def as_tensor_dict(self, fields: Optional[List[str]] = None): """Retrieves specified fields as a dictionary of tensors. Args: fields: (optional) list of fields to return in the dictionary. If None (default), all fields are returned. Returns: tensor_dict: A dictionary of tensors specified by fields. Raises: ValueError: if specified field is not contained in boxlist. """ tensor_dict = {} if fields is None: fields = self.get_all_fields() for field in fields: if not self.has_field(field): raise ValueError('boxlist must contain all specified fields') tensor_dict[field] = self.get_field(field) return tensor_dict #@property def device(self): return self.data['boxes'].device