'''Convnet encoder module.

'''

import torch
import torch.nn as nn

from cortex.built_ins.networks.utils import get_nonlinearity
from cortex_DIM.nn_modules.misc import Fold, Unfold, View


def infer_conv_size(w, k, s, p):
    '''Infers the next size after convolution.

    Args:
        w: Input size.
        k: Kernel size.
        s: Stride.
        p: Padding.

    Returns:
        int: Output size.

    '''
    x = (w - k + 2 * p) // s + 1
    return x
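
# Worked example of the formula above: a 32-pixel input with a 4x4 kernel,
# stride 2 and padding 1 gives (32 - 4 + 2 * 1) // 2 + 1 = 16 (the spatial
# size halves), while a 3x3 kernel with stride 1 and padding 1 gives
# (32 - 3 + 2 * 1) // 1 + 1 = 32 (the spatial size is preserved).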

class Convnet(nn.Module):
    '''Basic convnet convenience class.

    Attributes:
        conv_layers: nn.Sequential of nn.Conv2d layers with batch norm,
            dropout, nonlinearity.
        fc_layers: nn.Sequential of nn.Linear layers with batch norm,
            dropout, nonlinearity.
        reshape: Simple reshape layer.
        conv_shape: Shape of the convolutional output.

    '''

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.create_layers(*args, **kwargs)

    def create_layers(self, shape, conv_args=None, fc_args=None):
        '''Creates layers.

        conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
        fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)

        Args:
            shape: Shape of input.
            conv_args: List of tuple of convolutional arguments.
            fc_args: List of tuple of fully-connected arguments.

        '''
        self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)

        dim_x, dim_y, dim_out = self.conv_shape
        dim_r = dim_x * dim_y * dim_out
        self.reshape = View(-1, dim_r)
        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def create_conv_layers(self, shape, conv_args):
        '''Creates a set of convolutional layers.

        Args:
            shape: Input shape.
            conv_args: List of tuple of convolutional arguments.

        Returns:
            (nn.Sequential, tuple): The convolutional layers and their output
                shape (dim_x, dim_y, dim_out).

        '''
        conv_layers = nn.Sequential()
        conv_args = conv_args or []

        dim_x, dim_y, dim_in = shape

        for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            conv_block = nn.Sequential()

            if dim_out is not None:
                conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p,
                                 bias=not batch_norm)
                conv_block.add_module(name + 'conv', conv)
                dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
            else:
                dim_out = dim_in

            if dropout:
                conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
            if batch_norm:
                bn = nn.BatchNorm2d(dim_out)
                conv_block.add_module(name + 'bn', bn)

            if nonlinearity:
                nonlinearity = get_nonlinearity(nonlinearity)
                conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity)

            if pool:
                (pool_type, kernel, stride) = pool
                Pool = getattr(nn, pool_type)
                conv_block.add_module(name + 'pool', Pool(kernel_size=kernel, stride=stride))
                dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)

            conv_layers.add_module(name, conv_block)

            dim_in = dim_out

        dim_out = dim_in

        return conv_layers, (dim_x, dim_y, dim_out)

    def create_linear_layers(self, dim_in, fc_args):
        '''Creates a set of fully-connected layers.

        Args:
            dim_in: Number of input units.
            fc_args: List of tuple of fully-connected arguments.

        Returns:
            (nn.Sequential, int): The fully-connected layers and the number of
                output units.

        '''
        fc_layers = nn.Sequential()
        fc_args = fc_args or []

        for i, (dim_out, batch_norm, dropout, nonlinearity) in enumerate(fc_args):
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            fc_block = nn.Sequential()

            if dim_out is not None:
                fc_block.add_module(name + 'fc', nn.Linear(dim_in, dim_out))
            else:
                dim_out = dim_in

            if dropout:
                fc_block.add_module(name + 'do', nn.Dropout(p=dropout))
            if batch_norm:
                bn = nn.BatchNorm1d(dim_out)
                fc_block.add_module(name + 'bn', bn)
            if nonlinearity:
                nonlinearity = get_nonlinearity(nonlinearity)
                fc_block.add_module(nonlinearity.__class__.__name__, nonlinearity)

            fc_layers.add_module(name, fc_block)

            dim_in = dim_out

        return fc_layers, dim_in

    def next_size(self, dim_x, dim_y, k, s, p):
        '''Infers the next size of a convolutional layer.

        Args:
            dim_x: First dimension.
            dim_y: Second dimension.
            k: Kernel size.
            s: Stride.
            p: Padding.

        Returns:
            (int, int): (First output dimension, Second output dimension).

        '''
        if isinstance(k, int):
            kx, ky = (k, k)
        else:
            kx, ky = k

        if isinstance(s, int):
            sx, sy = (s, s)
        else:
            sx, sy = s

        if isinstance(p, int):
            px, py = (p, p)
        else:
            px, py = p

        return (infer_conv_size(dim_x, kx, sx, px),
                infer_conv_size(dim_y, ky, sy, py))

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass.

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            (torch.Tensor or list, torch.Tensor or list): Convolutional and
                fully-connected outputs.

        '''
        if return_full_list:
            conv_out = []
            for conv_layer in self.conv_layers:
                x = conv_layer(x)
                conv_out.append(x)
        else:
            conv_out = self.conv_layers(x)
            x = conv_out

        x = self.reshape(x)

        if return_full_list:
            fc_out = []
            for fc_layer in self.fc_layers:
                x = fc_layer(x)
                fc_out.append(x)
        else:
            fc_out = self.fc_layers(x)

        return conv_out, fc_out
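
def _example_convnet():
    '''Illustrative sketch: builds a small Convnet on 32x32x3 inputs.

    The layer settings below are illustrative assumptions, not values
    prescribed by this module; they only demonstrate the
    (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
    and (dim_h, batch_norm, dropout, nonlinearity) tuple formats, and assume
    `get_nonlinearity` resolves the string 'ReLU' to an nn.ReLU module.

    '''
    conv_args = [
        # (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool) -- illustrative
        (64, 4, 2, 1, True, False, 'ReLU', None),   # 32x32 -> 16x16
        (128, 4, 2, 1, True, False, 'ReLU', None),  # 16x16 -> 8x8
        (256, 4, 2, 1, True, False, 'ReLU', None),  # 8x8 -> 4x4
    ]
    fc_args = [
        # (dim_h, batch_norm, dropout, nonlinearity) -- illustrative
        (1024, True, False, 'ReLU'),
    ]
    net = Convnet((32, 32, 3), conv_args=conv_args, fc_args=fc_args)
    x = torch.randn(8, 3, 32, 32)  # NCHW input batch
    conv_out, fc_out = net(x)
    # Expected shapes: conv_out (8, 256, 4, 4), fc_out (8, 1024).
    return conv_out, fc_out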

class FoldedConvnet(Convnet):
    '''Convnet with strided crop input.

    '''

    def create_layers(self, shape, crop_size=8, conv_args=None, fc_args=None):
        '''Creates layers.

        conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
        fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)

        Args:
            shape: Shape of input.
            crop_size: Size of crops.
            conv_args: List of tuple of convolutional arguments.
            fc_args: List of tuple of fully-connected arguments.

        '''
        self.crop_size = crop_size

        dim_x, dim_y, dim_in = shape

        if dim_x != dim_y:
            raise ValueError('x and y dimensions must be the same to use Folded encoders.')

        self.final_size = 2 * (dim_x // self.crop_size) - 1

        self.unfold = Unfold(dim_x, self.crop_size)
        self.refold = Fold(dim_x, self.crop_size)

        shape = (self.crop_size, self.crop_size, dim_in)

        self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)

        dim_x, dim_y, dim_out = self.conv_shape
        dim_r = dim_x * dim_y * dim_out
        self.reshape = View(-1, dim_r)
        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def create_conv_layers(self, shape, conv_args):
        '''Creates a set of convolutional layers.

        Args:
            shape: Input shape.
            conv_args: List of tuple of convolutional arguments.

        Returns:
            (nn.Sequential, tuple): The convolutional layers and their output
                shape (dim_x, dim_y, dim_out).

        '''
        conv_layers = nn.Sequential()
        conv_args = conv_args or []

        dim_x, dim_y, dim_in = shape

        for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            conv_block = nn.Sequential()

            if dim_out is not None:
                conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p,
                                 bias=not batch_norm)
                conv_block.add_module(name + 'conv', conv)
                dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
            else:
                dim_out = dim_in

            if dropout:
                conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
            if batch_norm:
                bn = nn.BatchNorm2d(dim_out)
                conv_block.add_module(name + 'bn', bn)

            if nonlinearity:
                nonlinearity = get_nonlinearity(nonlinearity)
                conv_block.add_module(nonlinearity.__class__.__name__, nonlinearity)

            if pool:
                (pool_type, kernel, stride) = pool
                Pool = getattr(nn, pool_type)
                conv_block.add_module('pool', Pool(kernel_size=kernel, stride=stride))
                dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)

            conv_layers.add_module(name, conv_block)

            dim_in = dim_out

            if dim_x != dim_y:
                raise ValueError('dim_x and dim_y do not match.')

            if dim_x == 1:
                dim_x = self.final_size
                dim_y = self.final_size

        dim_out = dim_in

        return conv_layers, (dim_x, dim_y, dim_out)

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass.

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            (torch.Tensor or list, torch.Tensor or list): Convolutional and
                fully-connected outputs.

        '''
        x = self.unfold(x)

        conv_out = []
        for conv_layer in self.conv_layers:
            x = conv_layer(x)
            if x.size(2) == 1:
                x = self.refold(x)
            conv_out.append(x)

        x = self.reshape(x)

        if return_full_list:
            fc_out = []
            for fc_layer in self.fc_layers:
                x = fc_layer(x)
                fc_out.append(x)
        else:
            fc_out = self.fc_layers(x)

        if not return_full_list:
            conv_out = conv_out[-1]

        return conv_out, fc_out
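
def _example_folded_convnet():
    '''Illustrative sketch: builds a FoldedConvnet on 32x32x3 inputs.

    The settings below are illustrative assumptions, not values prescribed by
    this module. They assume that Unfold tiles each image into overlapping
    crop_size x crop_size crops folded into the batch dimension (stride
    crop_size // 2, hence final_size = 2 * (32 // 8) - 1 = 7 here), that Fold
    reassembles the per-crop 1x1 features into a final_size x final_size map,
    and that `get_nonlinearity` resolves 'ReLU'. The conv stack must reduce
    each crop to 1x1 spatially for the refold in forward() to trigger.

    '''
    conv_args = [
        # (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool) -- illustrative
        (64, 4, 2, 1, True, False, 'ReLU', None),   # 8x8 crop -> 4x4
        (128, 4, 2, 1, True, False, 'ReLU', None),  # 4x4 -> 2x2
        (256, 2, 1, 0, True, False, 'ReLU', None),  # 2x2 -> 1x1, then refolded to 7x7
    ]
    fc_args = [
        # (dim_h, batch_norm, dropout, nonlinearity) -- illustrative
        (1024, True, False, 'ReLU'),
    ]
    net = FoldedConvnet((32, 32, 3), crop_size=8, conv_args=conv_args, fc_args=fc_args)
    x = torch.randn(8, 3, 32, 32)  # NCHW input batch
    conv_out, fc_out = net(x)
    # Expected shapes (under the Fold/Unfold assumptions above):
    # conv_out (8, 256, 7, 7), fc_out (8, 1024).
    return conv_out, fc_out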