PLA-Net / gcn_lib /dense /torch_nn.py
juliocesar-io's picture
Added initial app
799e642
raw
history blame
3.14 kB
import torch
from torch import nn
from torch.nn import Sequential as Seq, Linear as Lin, Conv2d
##############################
# Basic layers
##############################
def act_layer(act, inplace=False, neg_slope=0.2, n_prelu=1):
    """Build an activation module selected by name.

    :param act: case-insensitive name: ``'relu'``, ``'leakyrelu'`` or ``'prelu'``
    :param inplace: forwarded to ``nn.ReLU`` / ``nn.LeakyReLU``
    :param neg_slope: negative slope for LeakyReLU, initial weight for PReLU
    :param n_prelu: number of learnable PReLU parameters
    :return: the freshly constructed activation module
    :raises NotImplementedError: if the name is not one of the three above
    """
    kind = act.lower()
    if kind == 'relu':
        return nn.ReLU(inplace=inplace)
    if kind == 'leakyrelu':
        return nn.LeakyReLU(negative_slope=neg_slope, inplace=inplace)
    if kind == 'prelu':
        return nn.PReLU(num_parameters=n_prelu, init=neg_slope)
    # The lower-cased name is what gets reported.
    raise NotImplementedError('activation layer [%s] is not found' % kind)
def norm_layer(norm, nc):
    """Build a 2-D normalization module selected by name.

    :param norm: case-insensitive name: ``'batch'`` or ``'instance'``
    :param nc: number of features handed to the norm constructor
    :return: ``nn.BatchNorm2d(nc, affine=True)`` for ``'batch'``,
        ``nn.InstanceNorm2d(nc, affine=False)`` for ``'instance'``
    :raises NotImplementedError: for any other name
    """
    kind = norm.lower()
    if kind == 'batch':
        return nn.BatchNorm2d(nc, affine=True)
    if kind == 'instance':
        return nn.InstanceNorm2d(nc, affine=False)
    # The lower-cased name is what gets reported.
    raise NotImplementedError('normalization layer [%s] is not found' % kind)
class MLP(Seq):
    """Sequential stack of ``Linear`` layers with optional activation and norm.

    For every consecutive pair ``(channels[i - 1], channels[i])`` a ``Linear``
    layer is appended, followed by an activation (via ``act_layer``) and a
    normalization layer (via ``norm_layer``) when those are requested.

    :param channels: sequence of layer widths; ``len(channels) - 1`` linear
        layers are created (none if fewer than two widths are given)
    :param act: activation name passed to ``act_layer``; ``None`` or
        ``'none'`` (any case) disables it
    :param norm: normalization name passed to ``norm_layer``; ``None`` or
        ``'none'`` (any case) disables it
    :param bias: whether the linear layers carry a bias term
    """

    def __init__(self, channels, act='relu', norm=None, bias=True):
        use_act = act is not None and act.lower() != 'none'
        use_norm = norm is not None and norm.lower() != 'none'

        layers = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(Lin(in_ch, out_ch, bias))
            if use_act:
                layers.append(act_layer(act))
            if use_norm:
                # Size the norm to THIS layer's output width. Using
                # channels[-1] here mis-sizes every non-final layer whose
                # width differs from the last one.
                layers.append(norm_layer(norm, out_ch))
        super().__init__(*layers)
class BasicConv(Seq):
    """Sequential stack of 1x1 ``Conv2d`` layers with optional extras.

    For every consecutive pair ``(channels[i - 1], channels[i])`` a 1x1
    ``Conv2d`` is appended, followed (when requested) by an activation from
    ``act_layer``, a normalization layer from ``norm_layer`` and a
    ``Dropout2d``. After construction all parameters are re-initialised by
    :meth:`reset_parameters`.

    :param channels: sequence of channel counts; ``len(channels) - 1`` conv
        layers are created
    :param act: activation name passed to ``act_layer``; ``None`` or
        ``'none'`` (any case) disables it
    :param norm: normalization name passed to ``norm_layer``; ``None`` or
        ``'none'`` (any case) disables it
    :param bias: whether the conv layers carry a bias term
    :param drop: dropout probability; dropout is added only when ``drop > 0``
    """

    def __init__(self, channels, act='relu', norm=None, bias=True, drop=0.):
        use_act = act is not None and act.lower() != 'none'
        use_norm = norm is not None and norm.lower() != 'none'

        layers = []
        for in_ch, out_ch in zip(channels[:-1], channels[1:]):
            layers.append(Conv2d(in_ch, out_ch, 1, bias=bias))
            if use_act:
                layers.append(act_layer(act))
            if use_norm:
                # Size the norm to THIS conv's output channels. Using
                # channels[-1] here mis-sizes every non-final layer whose
                # width differs from the last one.
                layers.append(norm_layer(norm, out_ch))
            if drop > 0:
                layers.append(nn.Dropout2d(drop))

        super().__init__(*layers)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise conv and norm parameters in place.

        Conv weights get Kaiming-normal init and conv biases are zeroed;
        BatchNorm2d / InstanceNorm2d weights are set to one and biases to
        zero when those parameters exist.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, (nn.BatchNorm2d, nn.InstanceNorm2d)):
                # Norm layers built with affine=False (the 'instance' case in
                # norm_layer) have weight/bias set to None; skip them instead
                # of raising AttributeError.
                if m.weight is not None:
                    nn.init.ones_(m.weight)
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
def batched_index_select(inputs, index):
    """Gather per-vertex feature vectors for every index in a batch.

    :param inputs: torch.Size([batch_size, num_dims, num_vertices, 1])
    :param index: torch.Size([batch_size, num_vertices, k]); integer vertex
        ids into the ``num_vertices`` axis of the same batch element. May be
        non-contiguous (e.g. a strided slice).
    :return: torch.Size([batch_size, num_dims, num_vertices, k])
    """
    batch_size, num_dims, num_vertices, _ = inputs.shape
    k = index.shape[2]

    # Flatten the features to one row per (batch, vertex): (B * N, D).
    flat_inputs = inputs.transpose(2, 1).reshape(-1, num_dims)

    # Per-batch row offset into flat_inputs, created directly with the
    # dtype/device of `index` so no cast or host-to-device copy is needed.
    offsets = torch.arange(
        batch_size, dtype=index.dtype, device=index.device
    ).view(batch_size, 1) * num_vertices

    # reshape (not view) so that non-contiguous index tensors are accepted.
    flat_index = (index.reshape(batch_size, -1) + offsets).reshape(-1)

    gathered = torch.index_select(flat_inputs, 0, flat_index)
    # (B * N * k, D) -> (B, N * k, D) -> (B, D, N * k) -> (B, D, N, k)
    return gathered.view(batch_size, -1, num_dims).transpose(2, 1).view(batch_size, num_dims, -1, k)