Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| import numpy as np | |
| import torch | |
| from mmcv.ops import RoIAlignRotated | |
| from .utils import (euclidean_distance_matrix, feature_embedding, | |
| normalize_adjacent_matrix) | |
class LocalGraphs:
    """Generate local graphs for GCN to classify the neighbors of a pivot for
    DRRG: Deep Relational Reasoning Graph Network for Arbitrary Shape Text
    Detection.

    [https://arxiv.org/abs/2003.07493]. This code was partially adapted from
    https://github.com/GXYM/DRRG licensed under the MIT license.

    Args:
        k_at_hops (tuple(int)): The number of i-hop neighbors, i = 1, 2.
        num_adjacent_linkages (int): The number of linkages when constructing
            adjacent matrix.
        node_geo_feat_len (int): The length of embedded geometric feature
            vector of a text component.
        pooling_scale (float): The spatial scale of rotated RoI-Align.
        pooling_output_size (tuple(int)): The output size of rotated RoI-Align.
        local_graph_thr(float): The threshold for filtering out identical local
            graphs.
    """

    def __init__(self, k_at_hops, num_adjacent_linkages, node_geo_feat_len,
                 pooling_scale, pooling_output_size, local_graph_thr):

        assert len(k_at_hops) == 2
        assert all(isinstance(n, int) for n in k_at_hops)
        assert isinstance(num_adjacent_linkages, int)
        assert isinstance(node_geo_feat_len, int)
        assert isinstance(pooling_scale, float)
        assert all(isinstance(n, int) for n in pooling_output_size)
        assert isinstance(local_graph_thr, float)

        self.k_at_hops = k_at_hops
        self.num_adjacent_linkages = num_adjacent_linkages
        # Stored under ``node_geo_feat_dim``; it is the embedding length
        # handed to ``feature_embedding`` in ``__call__``.
        self.node_geo_feat_dim = node_geo_feat_len
        self.pooling = RoIAlignRotated(pooling_output_size, pooling_scale)
        self.local_graph_thr = local_graph_thr

    def generate_local_graphs(self, sorted_dist_inds, gt_comp_labels):
        """Generate local graphs for GCN to predict which instance a text
        component belongs to.

        A local graph of a pivot consists of its ``k_at_hops[0]`` nearest
        neighbors plus the ``k_at_hops[1]`` nearest neighbors of each of
        them. A pivot is dropped when an already kept graph (a) has the same
        non-zero label, (b) lists the pivot among its k-nearest neighbors and
        (c) overlaps the pivot's neighbor set with an IoU above
        ``local_graph_thr``.

        Args:
            sorted_dist_inds (ndarray): The complete graph node indices, which
                is sorted according to the Euclidean distance.
            gt_comp_labels(ndarray): The ground truth labels define the
                instance to which the text components (nodes in graphs) belong.

        Returns:
            pivot_local_graphs(list[list[int]]): The list of local graph
                neighbor indices of pivots. The pivot itself is always the
                first entry of its graph.
            pivot_knns(list[list[int]]): The list of k-nearest neighbor indices
                of pivots. The pivot itself is always the first entry.
        """

        assert sorted_dist_inds.ndim == 2
        assert (sorted_dist_inds.shape[0] == sorted_dist_inds.shape[1] ==
                gt_comp_labels.shape[0])

        # Column 0 is skipped when taking nearest neighbors.
        knn_graph = sorted_dist_inds[:, 1:self.k_at_hops[0] + 1]
        pivot_local_graphs = []
        pivot_knns = []
        # Neighbor sets (pivot excluded) of the graphs kept so far, cached so
        # the IoU test below does not rebuild them for every comparison.
        kept_neighbor_sets = []

        for pivot_ind, knn in enumerate(knn_graph):

            local_graph_neighbors = set(knn)
            for neighbor_ind in knn:
                local_graph_neighbors.update(
                    set(sorted_dist_inds[neighbor_ind,
                                         1:self.k_at_hops[1] + 1]))
            local_graph_neighbors.discard(pivot_ind)

            pivot_local_graph = [pivot_ind] + list(local_graph_neighbors)
            pivot_knn = [pivot_ind] + list(knn)

            pivot_label = gt_comp_labels[pivot_ind]
            is_duplicate = False
            # Pivots labelled 0 are never filtered out.
            if pivot_label != 0:
                for added_knn, added_neighbors in zip(pivot_knns,
                                                      kept_neighbor_sets):
                    if gt_comp_labels[added_knn[0]] != pivot_label:
                        continue
                    if pivot_ind not in added_knn:
                        continue
                    union = len(local_graph_neighbors | added_neighbors)
                    intersect = len(local_graph_neighbors & added_neighbors)
                    # The epsilon guards against two empty neighbor sets.
                    local_graph_iou = intersect / (union + 1e-8)
                    if local_graph_iou > self.local_graph_thr:
                        is_duplicate = True
                        break

            if not is_duplicate:
                pivot_local_graphs.append(pivot_local_graph)
                pivot_knns.append(pivot_knn)
                kept_neighbor_sets.append(local_graph_neighbors)

        return pivot_local_graphs, pivot_knns

    def generate_gcn_input(self, node_feat_batch, node_label_batch,
                           local_graph_batch, knn_batch,
                           sorted_dist_ind_batch):
        """Generate graph convolution network input data.

        Every local graph is padded with zeros up to the size of the largest
        local graph in the batch so that the results can be stacked.

        Args:
            node_feat_batch (List[Tensor]): The batched graph node features.
            node_label_batch (List[ndarray]): The batched text component
                labels.
            local_graph_batch (List[List[list[int]]]): The local graph node
                indices of image batch.
            knn_batch (List[List[list[int]]]): The knn graph node indices of
                image batch.
            sorted_dist_ind_batch (list[ndarray]): The node indices sorted
                according to the Euclidean distance.

        Returns:
            local_graphs_node_feat (Tensor): The node features of graph.
            adjacent_matrices (Tensor): The adjacent matrices of local graphs.
            pivots_knn_inds (Tensor): The k-nearest neighbor indices in
                local graph.
            gt_linkage (Tensor): The supervision signal of GCN for linkage
                prediction.

        Raises:
            ValueError: If the batch contains no local graph at all.
        """

        assert isinstance(node_feat_batch, list)
        assert isinstance(node_label_batch, list)
        assert isinstance(local_graph_batch, list)
        assert isinstance(knn_batch, list)
        assert isinstance(sorted_dist_ind_batch, list)

        num_max_nodes = max(
            len(pivot_local_graph) for pivot_local_graphs in local_graph_batch
            for pivot_local_graph in pivot_local_graphs)

        local_graphs_node_feat = []
        adjacent_matrices = []
        pivots_knn_inds = []
        pivots_gt_linkage = []

        for batch_ind, sorted_dist_inds in enumerate(sorted_dist_ind_batch):
            node_feats = node_feat_batch[batch_ind]
            pivot_local_graphs = local_graph_batch[batch_ind]
            pivot_knns = knn_batch[batch_ind]
            node_labels = node_label_batch[batch_ind]
            device = node_feats.device

            for graph_ind, pivot_knn in enumerate(pivot_knns):
                pivot_local_graph = pivot_local_graphs[graph_ind]
                num_nodes = len(pivot_local_graph)
                pivot_ind = pivot_local_graph[0]
                # Global node index -> position inside this local graph.
                node2ind_map = {j: i for i, j in enumerate(pivot_local_graph)}

                # Kept on CPU: it is used below to index a NumPy array.
                knn_inds = torch.tensor(
                    [node2ind_map[i] for i in pivot_knn[1:]])

                # Node features are expressed relative to the pivot.
                pivot_feats = node_feats[pivot_ind]
                normalized_feats = node_feats[pivot_local_graph] - pivot_feats

                adjacent_matrix = np.zeros((num_nodes, num_nodes),
                                           dtype=np.float32)
                for node in pivot_local_graph:
                    row = node2ind_map[node]
                    neighbors = sorted_dist_inds[
                        node, 1:self.num_adjacent_linkages + 1]
                    for neighbor in neighbors:
                        # Dict lookup instead of scanning the node list.
                        col = node2ind_map.get(neighbor)
                        if col is not None:
                            adjacent_matrix[row, col] = 1
                            adjacent_matrix[col, row] = 1

                adjacent_matrix = normalize_adjacent_matrix(adjacent_matrix)
                pad_adjacent_matrix = torch.zeros(
                    (num_max_nodes, num_max_nodes),
                    dtype=torch.float,
                    device=device)
                pad_adjacent_matrix[:num_nodes, :num_nodes] = torch.from_numpy(
                    adjacent_matrix)

                pad_normalized_feats = torch.cat([
                    normalized_feats,
                    torch.zeros(
                        (num_max_nodes - num_nodes, normalized_feats.shape[1]),
                        dtype=torch.float,
                        device=device)
                ],
                                                 dim=0)

                local_graph_labels = node_labels[pivot_local_graph]
                knn_labels = local_graph_labels[knn_inds.numpy()]
                # A link is positive only when both ends share the same
                # label and that label is greater than zero.
                link_labels = ((node_labels[pivot_ind] == knn_labels) &
                               (node_labels[pivot_ind] > 0)).astype(np.int64)
                link_labels = torch.from_numpy(link_labels)

                local_graphs_node_feat.append(pad_normalized_feats)
                adjacent_matrices.append(pad_adjacent_matrix)
                pivots_knn_inds.append(knn_inds)
                pivots_gt_linkage.append(link_labels)

        local_graphs_node_feat = torch.stack(local_graphs_node_feat, 0)
        adjacent_matrices = torch.stack(adjacent_matrices, 0)
        pivots_knn_inds = torch.stack(pivots_knn_inds, 0)
        pivots_gt_linkage = torch.stack(pivots_gt_linkage, 0)

        return (local_graphs_node_feat, adjacent_matrices, pivots_knn_inds,
                pivots_gt_linkage)

    def __call__(self, feat_maps, comp_attribs):
        """Generate local graphs as GCN input.

        Args:
            feat_maps (Tensor): The feature maps to extract the content
                features of text components.
            comp_attribs (ndarray): The text component attributes, of shape
                (batch, max_components, 8). Entry ``[b, 0, 0]`` holds the
                number of valid components of image ``b``, columns 1-6 hold
                the geometric attributes and column 7 the component label.
                NOTE: the second-to-last geometric column is clipped to
                [-1, 1] in place.

        Returns:
            local_graphs_node_feat (Tensor): The node features of graph.
            adjacent_matrices (Tensor): The adjacent matrices of local graphs.
            pivots_knn_inds (Tensor): The k-nearest neighbor indices in local
                graph.
            gt_linkage (Tensor): The supervision signal of GCN for linkage
                prediction.
        """

        assert isinstance(feat_maps, torch.Tensor)
        assert comp_attribs.ndim == 3
        assert comp_attribs.shape[2] == 8

        sorted_dist_inds_batch = []
        local_graph_batch = []
        knn_batch = []
        node_feat_batch = []
        node_label_batch = []
        device = feat_maps.device

        for batch_ind in range(comp_attribs.shape[0]):
            num_comps = int(comp_attribs[batch_ind, 0, 0])
            # presumably (x, y, h, w, cos, sin) -- TODO confirm with caller.
            comp_geo_attribs = comp_attribs[batch_ind, :num_comps, 1:7]
            node_labels = comp_attribs[batch_ind, :num_comps,
                                       7].astype(np.int32)

            comp_centers = comp_geo_attribs[:, 0:2]
            distance_matrix = euclidean_distance_matrix(
                comp_centers, comp_centers)

            # The RoI batch index is always 0 because the pooling below is
            # fed one image at a time.
            batch_id = np.zeros(
                (comp_geo_attribs.shape[0], 1), dtype=np.float32)
            # ``comp_geo_attribs`` is a view, so this clip also writes into
            # ``comp_attribs``; it keeps ``arccos`` inside its domain.
            comp_geo_attribs[:, -2] = np.clip(comp_geo_attribs[:, -2], -1, 1)
            angle = np.arccos(comp_geo_attribs[:, -2]) * np.sign(
                comp_geo_attribs[:, -1])
            angle = angle.reshape((-1, 1))
            rotated_rois = np.hstack(
                [batch_id, comp_geo_attribs[:, :-2], angle])
            rois = torch.from_numpy(rotated_rois).to(device)
            content_feats = self.pooling(feat_maps[batch_ind].unsqueeze(0),
                                         rois)

            # Flatten the pooled features into one vector per component.
            content_feats = content_feats.view(content_feats.shape[0],
                                               -1).to(device)
            geo_feats = feature_embedding(comp_geo_attribs,
                                          self.node_geo_feat_dim)
            geo_feats = torch.from_numpy(geo_feats).to(device)
            node_feats = torch.cat([content_feats, geo_feats], dim=-1)

            sorted_dist_inds = np.argsort(distance_matrix, axis=1)
            pivot_local_graphs, pivot_knns = self.generate_local_graphs(
                sorted_dist_inds, node_labels)

            node_feat_batch.append(node_feats)
            node_label_batch.append(node_labels)
            local_graph_batch.append(pivot_local_graphs)
            knn_batch.append(pivot_knns)
            sorted_dist_inds_batch.append(sorted_dist_inds)

        (node_feats, adjacent_matrices, knn_inds, gt_linkage) = \
            self.generate_gcn_input(node_feat_batch,
                                    node_label_batch,
                                    local_graph_batch,
                                    knn_batch,
                                    sorted_dist_inds_batch)

        return node_feats, adjacent_matrices, knn_inds, gt_linkage