import torch
from torch import nn
from torch.nn import Sequential as Seq, Linear as Lin, Conv2d


##############################
#    Basic layers
##############################
def _is_enabled(option):
    """Return True unless ``option`` is None or the string 'none' (any case)."""
    return option is not None and option.lower() != 'none'


def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    """Build an activation module from its name.

    Args:
        act: case-insensitive name: 'relu', 'leakyrelu' or 'prelu'.
        inplace: forwarded to ``nn.ReLU`` / ``nn.LeakyReLU``.
        neg_slope: negative slope for LeakyReLU; initial slope for PReLU.
        n_prelu: number of learnable PReLU slopes.

    Raises:
        NotImplementedError: if ``act`` is not one of the names above.
    """
    act = act.lower()
    if act == 'relu':
        return nn.ReLU(inplace)
    if act == 'leakyrelu':
        return nn.LeakyReLU(neg_slope, inplace)
    if act == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    raise NotImplementedError('activation layer [%s] is not found' % act)


def norm_layer(norm, nc, dim=2):
    """Build a normalization module from its name.

    Args:
        norm: case-insensitive name: 'batch' (affine BatchNorm) or
            'instance' (non-affine InstanceNorm).
        nc: number of channels / features being normalized.
        dim: 2 (default, original behaviour) builds ``*Norm2d`` for
            (N, C, H, W) tensors; 1 builds ``*Norm1d`` for (N, C) tensors.

    Raises:
        ValueError: if ``dim`` is not 1 or 2.
        NotImplementedError: if ``norm`` is not a known name.
    """
    norm = norm.lower()
    if dim not in (1, 2):
        raise ValueError('dim must be 1 or 2, got %r' % (dim,))
    if norm == 'batch':
        cls = nn.BatchNorm1d if dim == 1 else nn.BatchNorm2d
        return cls(nc, affine=True)
    if norm == 'instance':
        cls = nn.InstanceNorm1d if dim == 1 else nn.InstanceNorm2d
        return cls(nc, affine=False)
    raise NotImplementedError('normalization layer [%s] is not found' % norm)


class MLP(Seq):
    """Stack of Linear layers, each optionally followed by an activation and a norm.

    ``channels`` lists the feature widths layer by layer, e.g.
    ``[in, hidden, out]`` builds ``in -> hidden -> out``. Norms are 1-D
    (``BatchNorm1d`` / ``InstanceNorm1d``) because a Linear output carries its
    features on the last axis; with ``norm='batch'`` inputs are expected to be
    (batch, features).
    """

    def __init__(self, channels, act='relu', norm=None, bias=True):
        m = []
        for i in range(1, len(channels)):
            m.append(Lin(channels[i - 1], channels[i], bias))
            if _is_enabled(act):
                m.append(act_layer(act))
            if _is_enabled(norm):
                # Normalize the width this layer actually outputs. Using
                # channels[-1] broke every intermediate layer whose width
                # differs from the final one.
                # NOTE(review): torch treats a 2-D input to InstanceNorm1d as
                # an unbatched (C, L) tensor; confirm 'instance' is meaningful
                # for MLP callers.
                m.append(norm_layer(norm, channels[i], dim=1))
        super().__init__(*m)


class BasicConv(Seq):
    """Stack of 1x1 Conv2d layers, each optionally followed by act, norm and dropout.

    Operates on (batch, channels, H, W) tensors; ``channels`` lists the channel
    widths layer by layer. Parameters are (re)initialized on construction.
    """

    def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.):
        m = []
        for i in range(1, len(channels)):
            m.append(Conv2d(channels[i - 1], channels[i], 1, bias=bias))
            if _is_enabled(act):
                m.append(act_layer(act))
            if _is_enabled(norm):
                # Size the norm by this layer's output width, not channels[-1].
                m.append(norm_layer(norm, channels[i]))
            if drop > 0:
                m.append(nn.Dropout2d(drop))
        super().__init__(*m)
        self.reset_parameters()

    def reset_parameters(self):
        """Kaiming-init conv weights, zero conv biases, reset norm affine params."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # norm_layer builds InstanceNorm2d with affine=False, so its
                # weight/bias are None; touching them unconditionally crashed.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)


def batched_index_select(inputs, index):
    """Gather, per batch element, the feature columns selected by ``index``.

    :param inputs: torch.Size([batch_size, num_dims, num_vertices, 1])
    :param index: torch.Size([batch_size, num_queries, k]) integer tensor on
        the same device as ``inputs``, with values in ``[0, num_vertices)``;
        ``num_queries`` may differ from ``num_vertices``. Need not be
        contiguous.
    :return: torch.Size([batch_size, num_dims, num_queries, k]) with
        ``out[b, :, n, j] == inputs[b, :, index[b, n, j], 0]``.
    """
    batch_size, num_dims, num_vertices, _ = inputs.shape
    k = index.shape[2]
    # Shift each batch's indices so they address rows of the flattened
    # (batch_size * num_vertices, num_dims) feature matrix built below.
    # Created directly on the target device/dtype (no host->device copy).
    offsets = torch.arange(batch_size, device=inputs.device, dtype=index.dtype) * num_vertices
    # reshape (not view) so non-contiguous index tensors are accepted.
    flat_index = (index + offsets.view(-1, 1, 1)).reshape(-1)
    flat_inputs = inputs.transpose(2, 1).reshape(-1, num_dims)
    gathered = torch.index_select(flat_inputs, 0, flat_index)
    return gathered.view(batch_size, -1, num_dims).transpose(2, 1).reshape(batch_size, num_dims, -1, k)