Spaces:
Runtime error
Runtime error
'''Convnet encoder module. | |
''' | |
import torch | |
import torch.nn as nn | |
#from cortex.built_ins.networks.utils import get_nonlinearity | |
from cortex_DIM.nn_modules.misc import Fold, Unfold, View | |
def infer_conv_size(w, k, s, p):
    '''Computes the spatial output size of a convolution / pooling window.

    Uses ``floor((w - k + 2 * p) / s) + 1`` with Python floor division.

    Args:
        w: Input size along one spatial dimension.
        k: Kernel size along that dimension.
        s: Stride along that dimension.
        p: Padding applied to each side along that dimension.

    Returns:
        int: Output size along that dimension.
    '''
    covered = w - k + 2 * p
    window_steps = covered // s
    return window_steps + 1
class Convnet(nn.Module):
    '''Basic convnet convenience class.

    Attributes:
        conv_layers: nn.Sequential of conv blocks; each block may contain an
            nn.Conv2d, dropout, batch norm, a nonlinearity and a pooling layer.
        fc_layers: nn.Sequential of fully-connected blocks; each block may
            contain an nn.Linear, dropout, batch norm and a nonlinearity.
        reshape: ``View(-1, dim_r)`` layer applied between the conv and fc stacks.
        conv_shape: ``(dim_x, dim_y, channels)`` of the convolutional output.
    '''

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.create_layers(*args, **kwargs)

    @staticmethod
    def _build_nonlinearity(nonlinearity):
        '''Turns a nonlinearity spec into an ``nn.Module`` instance.

        ``get_nonlinearity`` (from cortex) is not imported by this module, so
        this resolver is used instead; it covers the specs that can be
        registered with ``add_module``.

        Args:
            nonlinearity: An ``nn.Module`` instance, an ``nn.Module`` subclass,
                or the name of an ``nn.Module`` subclass in ``torch.nn``
                (e.g. ``'ReLU'``).

        Returns:
            nn.Module: The given instance, or a default-constructed instance of
            the given / named class.

        Raises:
            ValueError: If the spec does not resolve to an ``nn.Module``.
        '''
        if isinstance(nonlinearity, nn.Module):
            return nonlinearity
        if isinstance(nonlinearity, str):
            module_cls = getattr(nn, nonlinearity, None)
        else:
            module_cls = nonlinearity
        if isinstance(module_cls, type) and issubclass(module_cls, nn.Module):
            return module_cls()
        raise ValueError('Unsupported nonlinearity: {!r}'.format(nonlinearity))

    def create_layers(self, shape, conv_args=None, fc_args=None):
        '''Creates the conv stack, the reshape layer and the fc stack.

        conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
        fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)

        Args:
            shape: Input shape as ``(dim_x, dim_y, channels)``.
            conv_args: List of tuple of convolutional arguments.
            fc_args: List of tuple of fully-connected arguments.
        '''
        self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)
        dim_x, dim_y, dim_out = self.conv_shape
        dim_r = dim_x * dim_y * dim_out
        # NOTE(review): View presumably reshapes to (-1, dim_r), i.e. flattens
        # each conv output for the fc stack -- confirm in cortex_DIM.nn_modules.misc.
        self.reshape = View(-1, dim_r)
        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def create_conv_layers(self, shape, conv_args):
        '''Creates a set of convolutional layers.

        Each entry of ``conv_args`` is
        ``(dim_out, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)``.
        A ``dim_out`` of ``None`` skips the convolution and keeps the channel
        count; ``pool`` is falsy or ``(pool_type, kernel, stride)`` where
        ``pool_type`` names a class in ``torch.nn`` (e.g. ``'MaxPool2d'``).

        Args:
            shape: Input shape as ``(dim_x, dim_y, channels)``.
            conv_args: List of tuple of convolutional arguments.

        Returns:
            (nn.Sequential, tuple): The conv blocks and the output shape
            ``(dim_x, dim_y, channels)``.
        '''
        conv_layers = nn.Sequential()
        conv_args = conv_args or []
        dim_x, dim_y, dim_in = shape
        for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
            # The name is built before dim_out may be replaced below, so a
            # conv-less block is named with 'None' (kept for state_dict keys).
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            conv_block = nn.Sequential()
            if dim_out is not None:
                # Batch norm supplies its own shift, so the conv bias is dropped.
                conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p, bias=not batch_norm)
                conv_block.add_module(name + 'conv', conv)
                dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
            else:
                dim_out = dim_in
            if dropout:
                conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
            if batch_norm:
                conv_block.add_module(name + 'bn', nn.BatchNorm2d(dim_out))
            if nonlinearity:
                activation = self._build_nonlinearity(nonlinearity)
                conv_block.add_module(type(activation).__name__, activation)
            if pool:
                pool_type, kernel, stride = pool
                pool_cls = getattr(nn, pool_type)
                conv_block.add_module(name + 'pool', pool_cls(kernel_size=kernel, stride=stride))
                # Pooling is sized as an unpadded window.
                dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)
            conv_layers.add_module(name, conv_block)
            dim_in = dim_out
        # dim_in now holds the last block's channels (or the input channels
        # when conv_args is empty).
        dim_out = dim_in
        return conv_layers, (dim_x, dim_y, dim_out)

    def create_linear_layers(self, dim_in, fc_args):
        '''Creates a set of fully-connected layers.

        Each entry of ``fc_args`` is ``(dim_out, batch_norm, dropout, nonlinearity)``;
        a ``dim_out`` of ``None`` skips the linear layer and keeps the width.

        Args:
            dim_in: Number of input units.
            fc_args: List of tuple of fully-connected arguments.

        Returns:
            (nn.Sequential, int): The fc blocks and the number of output units.
        '''
        fc_layers = nn.Sequential()
        fc_args = fc_args or []
        for i, (dim_out, batch_norm, dropout, nonlinearity) in enumerate(fc_args):
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            fc_block = nn.Sequential()
            if dim_out is not None:
                fc_block.add_module(name + 'fc', nn.Linear(dim_in, dim_out))
            else:
                dim_out = dim_in
            if dropout:
                fc_block.add_module(name + 'do', nn.Dropout(p=dropout))
            if batch_norm:
                fc_block.add_module(name + 'bn', nn.BatchNorm1d(dim_out))
            if nonlinearity:
                activation = self._build_nonlinearity(nonlinearity)
                fc_block.add_module(type(activation).__name__, activation)
            fc_layers.add_module(name, fc_block)
            dim_in = dim_out
        return fc_layers, dim_in

    def next_size(self, dim_x, dim_y, k, s, p):
        '''Infers the next spatial size after a convolution or pooling layer.

        Args:
            dim_x: First dimension.
            dim_y: Second dimension.
            k: Kernel size (int, or an (x, y) pair).
            s: Stride (int, or an (x, y) pair).
            p: Padding (int, or an (x, y) pair).

        Returns:
            (int, int): (First output dimension, Second output dimension)
        '''
        def as_pair(value):
            # An int applies to both dims; anything else must unpack to (x, y).
            if isinstance(value, int):
                return value, value
            first, second = value
            return first, second

        kx, ky = as_pair(k)
        sx, sy = as_pair(s)
        px, py = as_pair(p)
        return (infer_conv_size(dim_x, kx, sx, px),
                infer_conv_size(dim_y, ky, sy, py))

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass.

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            tuple: ``(conv_out, fc_out)``. Each is a tensor, or a list holding
            every block's output when ``return_full_list`` is set.
        '''
        if return_full_list:
            conv_out = []
            for conv_layer in self.conv_layers:
                x = conv_layer(x)
                conv_out.append(x)
        else:
            conv_out = self.conv_layers(x)
            x = conv_out
        x = self.reshape(x)
        if return_full_list:
            fc_out = []
            for fc_layer in self.fc_layers:
                x = fc_layer(x)
                fc_out.append(x)
        else:
            fc_out = self.fc_layers(x)
        return conv_out, fc_out
class FoldedConvnet(Convnet):
    '''Convnet with strided crop input.

    The input is split into crops by ``Unfold``; the conv blocks run on the
    crops, and as soon as a block's output has spatial size 1 it is passed
    through ``Fold`` so later blocks operate on the folded map.
    '''

    @staticmethod
    def _build_nonlinearity(nonlinearity):
        '''Turns a nonlinearity spec into an ``nn.Module`` instance.

        ``get_nonlinearity`` (from cortex) is not imported by this module, so
        this resolver is used instead.

        Args:
            nonlinearity: An ``nn.Module`` instance, an ``nn.Module`` subclass,
                or the name of an ``nn.Module`` subclass in ``torch.nn``.

        Returns:
            nn.Module: The given instance, or a default-constructed instance of
            the given / named class.

        Raises:
            ValueError: If the spec does not resolve to an ``nn.Module``.
        '''
        if isinstance(nonlinearity, nn.Module):
            return nonlinearity
        if isinstance(nonlinearity, str):
            module_cls = getattr(nn, nonlinearity, None)
        else:
            module_cls = nonlinearity
        if isinstance(module_cls, type) and issubclass(module_cls, nn.Module):
            return module_cls()
        raise ValueError('Unsupported nonlinearity: {!r}'.format(nonlinearity))

    def create_layers(self, shape, crop_size=8, conv_args=None, fc_args=None):
        '''Creates the unfold/fold helpers, the conv stack and the fc stack.

        conv_args are in format (dim_h, f_size, stride, pad, batch_norm, dropout, nonlinearity, pool)
        fc_args are in format (dim_h, batch_norm, dropout, nonlinearity)

        Args:
            shape: Input shape as ``(dim_x, dim_y, channels)``; must be square.
            crop_size: Size of crops; must be in ``[1, dim_x]``.
            conv_args: List of tuple of convolutional arguments.
            fc_args: List of tuple of fully-connected arguments.

        Raises:
            ValueError: If the input is not square or crop_size is out of range.
        '''
        self.crop_size = crop_size
        dim_x, dim_y, dim_in = shape
        if dim_x != dim_y:
            raise ValueError('x and y dimensions must be the same to use Folded encoders.')
        if crop_size < 1 or crop_size > dim_x:
            # Otherwise dim_x // crop_size is 0 and final_size becomes -1.
            raise ValueError('crop_size must be between 1 and the input size ({}), got {}.'.format(dim_x, crop_size))
        # NOTE(review): 2 * n - 1 suggests crops taken at half-crop stride;
        # confirm against Unfold/Fold in cortex_DIM.nn_modules.misc.
        self.final_size = 2 * (dim_x // self.crop_size) - 1
        self.unfold = Unfold(dim_x, self.crop_size)
        self.refold = Fold(dim_x, self.crop_size)
        # The conv stack is sized for a single crop.
        shape = (self.crop_size, self.crop_size, dim_in)
        self.conv_layers, self.conv_shape = self.create_conv_layers(shape, conv_args)
        dim_x, dim_y, dim_out = self.conv_shape
        dim_r = dim_x * dim_y * dim_out
        self.reshape = View(-1, dim_r)
        self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)

    def create_conv_layers(self, shape, conv_args):
        '''Creates a set of convolutional layers for the crop-wise input.

        Same argument format as ``Convnet.create_conv_layers``. Whenever the
        tracked spatial size reaches 1, it is reset to ``self.final_size`` to
        match the refold performed in ``forward``.

        Args:
            shape: Input (crop) shape as ``(dim_x, dim_y, channels)``.
            conv_args: List of tuple of convolutional arguments.

        Returns:
            (nn.Sequential, tuple): The conv blocks and the output shape
            ``(dim_x, dim_y, channels)``.

        Raises:
            ValueError: If a block produces a non-square spatial size.
        '''
        conv_layers = nn.Sequential()
        conv_args = conv_args or []
        dim_x, dim_y, dim_in = shape
        for i, (dim_out, f, s, p, batch_norm, dropout, nonlinearity, pool) in enumerate(conv_args):
            name = '({}/{})_{}'.format(dim_in, dim_out, i + 1)
            conv_block = nn.Sequential()
            if dim_out is not None:
                conv = nn.Conv2d(dim_in, dim_out, kernel_size=f, stride=s, padding=p, bias=not batch_norm)
                conv_block.add_module(name + 'conv', conv)
                dim_x, dim_y = self.next_size(dim_x, dim_y, f, s, p)
            else:
                dim_out = dim_in
            if dropout:
                conv_block.add_module(name + 'do', nn.Dropout2d(p=dropout))
            if batch_norm:
                conv_block.add_module(name + 'bn', nn.BatchNorm2d(dim_out))
            if nonlinearity:
                activation = self._build_nonlinearity(nonlinearity)
                conv_block.add_module(type(activation).__name__, activation)
            if pool:
                pool_type, kernel, stride = pool
                pool_cls = getattr(nn, pool_type)
                # Module name kept as plain 'pool' to preserve existing layer names.
                conv_block.add_module('pool', pool_cls(kernel_size=kernel, stride=stride))
                dim_x, dim_y = self.next_size(dim_x, dim_y, kernel, stride, 0)
            conv_layers.add_module(name, conv_block)
            dim_in = dim_out
            if dim_x != dim_y:
                raise ValueError('dim_x and dim_y do not match.')
            if dim_x == 1:
                # forward() refolds here, so later blocks see the folded map.
                dim_x = self.final_size
                dim_y = self.final_size
        dim_out = dim_in
        return conv_layers, (dim_x, dim_y, dim_out)

    def forward(self, x: torch.Tensor, return_full_list=False):
        '''Forward pass.

        Args:
            x: Input.
            return_full_list: Optional, returns all layer outputs.

        Returns:
            tuple: ``(conv_out, fc_out)``. Each is a tensor, or a list holding
            every block's output when ``return_full_list`` is set.
        '''
        x = self.unfold(x)
        conv_outs = []
        for conv_layer in self.conv_layers:
            x = conv_layer(x)
            # A spatial size of 1 means each crop is reduced to a vector:
            # fold the crops back together.
            if x.size(2) == 1:
                x = self.refold(x)
            conv_outs.append(x)
        # x is the last block's output (or the unfolded input when there are
        # no conv blocks), so no IndexError on an empty stack.
        conv_out = conv_outs if return_full_list else x
        x = self.reshape(x)
        if return_full_list:
            fc_out = []
            for fc_layer in self.fc_layers:
                x = fc_layer(x)
                fc_out.append(x)
        else:
            fc_out = self.fc_layers(x)
        return conv_out, fc_out