'''Module for making resnet encoders.
'''
import torch
import torch.nn as nn
from cortex_DIM.nn_modules.convnet import Convnet
from cortex_DIM.nn_modules.misc import Fold, Unfold, View
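# Index of the nonlinearity entry within each conv_args layer tuple
# (ResBlock uses it to split the final nonlinearity out of the block).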
_nonlin_idx = 6
class ResBlock(Convnet):
    '''Residual block for a ResNet.
'''
def create_layers(self, shape, conv_args=None):
'''Creates layers
Args:
shape: Shape of input.
conv_args: Layer arguments for block.
'''
# Move nonlinearity to a separate step for residual.
final_nonlin = conv_args[-1][_nonlin_idx]
conv_args[-1] = list(conv_args[-1])
conv_args[-1][_nonlin_idx] = None
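        # Append a layer that applies only the nonlinearity, so forward() can
        # run it after the residual has been added.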
conv_args.append((None, 0, 0, 0, False, False, final_nonlin, None))
super().create_layers(shape, conv_args=conv_args)
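        # If the block changes spatial size or channel count, project the input
        # with a strided 1x1 convolution so the shapes match for the sum.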
if self.conv_shape != shape:
dim_x, dim_y, dim_in = shape
dim_x_, dim_y_, dim_out = self.conv_shape
stride = dim_x // dim_x_
next_x, _ = self.next_size(dim_x, dim_y, 1, stride, 0)
assert next_x == dim_x_, (self.conv_shape, shape)
self.downsample = nn.Sequential(
nn.Conv2d(dim_in, dim_out, kernel_size=1, stride=stride, padding=0, bias=False),
nn.BatchNorm2d(dim_out),
)
else:
self.downsample = None
def forward(self, x: torch.Tensor):
'''Forward pass
Args:
x: Input.
Returns:
            torch.Tensor.
'''
if self.downsample is not None:
residual = self.downsample(x)
else:
residual = x
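        # All layers except the final nonlinearity-only one, plus the residual,
        # then the nonlinearity applied to the sum.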
x = self.conv_layers[-1](self.conv_layers[:-1](x) + residual)
return x
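# A minimal usage sketch for ResBlock. The layout of each conv_args tuple is an
# assumption inferred from Convnet.create_conv_layers and from the nonlinearity
# slot at _nonlin_idx; treat everything but the call pattern as illustrative.
#
#     block = ResBlock((32, 32, 64),
#                      conv_args=[(64, 3, 1, 1, False, True, 'ReLU', None),
#                                 (64, 3, 1, 1, False, True, 'ReLU', None)])
#     out = block(torch.randn(8, 64, 32, 32))  # shape preserved: (8, 64, 32, 32)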
class ResNet(Convnet):
    '''ResNet encoder.
    '''
def create_layers(self, shape, conv_before_args=None, res_args=None, conv_after_args=None, fc_args=None):
'''Creates layers
Args:
shape: Shape of the input.
conv_before_args: Arguments for convolutional layers before residuals.
res_args: Residual args.
conv_after_args: Arguments for convolutional layers after residuals.
fc_args: Fully-connected arguments.
'''
self.conv_before_layers, self.conv_before_shape = self.create_conv_layers(shape, conv_before_args)
self.res_layers, self.res_shape = self.create_res_layers(self.conv_before_shape, res_args)
self.conv_after_layers, self.conv_after_shape = self.create_conv_layers(self.res_shape, conv_after_args)
dim_x, dim_y, dim_out = self.conv_after_shape
dim_r = dim_x * dim_y * dim_out
self.reshape = View(-1, dim_r)
self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)
def create_res_layers(self, shape, block_args=None):
'''Creates a set of residual blocks.
Args:
shape: input shape.
block_args: Arguments for blocks.
Returns:
nn.Sequential: sequence of residual blocks.
'''
res_layers = nn.Sequential()
block_args = block_args or []
for i, (conv_args, n_blocks) in enumerate(block_args):
block = ResBlock(shape, conv_args=conv_args)
res_layers.add_module('block_{}_0'.format(i), block)
for j in range(1, n_blocks):
shape = block.conv_shape
block = ResBlock(shape, conv_args=conv_args)
res_layers.add_module('block_{}_{}'.format(i, j), block)
shape = block.conv_shape
return res_layers, shape
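    # Sketch of the expected block_args layout (stage*_conv_args are
    # placeholders): each entry pairs one conv_args spec with a repeat count,
    # and shapes chain between consecutive blocks.
    #
    #     res_args = [(stage1_conv_args, 2), (stage2_conv_args, 2)]
    #     layers, out_shape = self.create_res_layers(in_shape, res_args)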
def forward(self, x: torch.Tensor, return_full_list=False):
'''Forward pass
Args:
x: Input.
return_full_list: Optional, returns all layer outputs.
Returns:
            Tuple of outputs from the conv_before, residual, conv_after, and
            fully-connected stages (each a list of per-layer outputs if
            return_full_list is set).
'''
if return_full_list:
conv_before_out = []
for conv_layer in self.conv_before_layers:
x = conv_layer(x)
conv_before_out.append(x)
else:
            conv_before_out = self.conv_before_layers(x)
x = conv_before_out
if return_full_list:
res_out = []
for res_layer in self.res_layers:
x = res_layer(x)
res_out.append(x)
else:
res_out = self.res_layers(x)
x = res_out
if return_full_list:
conv_after_out = []
for conv_layer in self.conv_after_layers:
x = conv_layer(x)
conv_after_out.append(x)
else:
conv_after_out = self.conv_after_layers(x)
x = conv_after_out
x = self.reshape(x)
if return_full_list:
fc_out = []
for fc_layer in self.fc_layers:
x = fc_layer(x)
fc_out.append(x)
else:
fc_out = self.fc_layers(x)
return conv_before_out, res_out, conv_after_out, fc_out
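# A minimal usage sketch for ResNet. Layer-argument layouts are assumptions
# inferred from how they are consumed above, not a documented API.
#
#     net = ResNet((32, 32, 3),
#                  conv_before_args=[(64, 3, 1, 1, False, True, 'ReLU', None)],
#                  res_args=[([(64, 3, 1, 1, False, True, 'ReLU', None),
#                              (64, 3, 1, 1, False, True, 'ReLU', None)], 2)])
#     conv_out, res_out, conv_after_out, fc_out = net(torch.randn(8, 3, 32, 32))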
class FoldedResNet(ResNet):
    '''ResNet that encodes strided, overlapping crops of its input.
'''
def create_layers(self, shape, crop_size=8, conv_before_args=None, res_args=None,
conv_after_args=None, fc_args=None):
'''Creates layers
Args:
shape: Shape of the input.
crop_size: Size of the crops.
conv_before_args: Arguments for convolutional layers before residuals.
res_args: Residual args.
conv_after_args: Arguments for convolutional layers after residuals.
fc_args: Fully-connected arguments.
'''
self.crop_size = crop_size
dim_x, dim_y, dim_in = shape
self.final_size = 2 * (dim_x // self.crop_size) - 1
self.unfold = Unfold(dim_x, self.crop_size)
self.refold = Fold(dim_x, self.crop_size)
shape = (self.crop_size, self.crop_size, dim_in)
self.conv_before_layers, self.conv_before_shape = self.create_conv_layers(shape, conv_before_args)
self.res_layers, self.res_shape = self.create_res_layers(self.conv_before_shape, res_args)
self.conv_after_layers, self.conv_after_shape = self.create_conv_layers(self.res_shape, conv_after_args)
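        # The post-residual shape is pinned back to the residual-stage shape;
        # this assumes any conv_after layers preserve the (refolded) shape.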
self.conv_after_shape = self.res_shape
dim_x, dim_y, dim_out = self.conv_after_shape
dim_r = dim_x * dim_y * dim_out
self.reshape = View(-1, dim_r)
self.fc_layers, _ = self.create_linear_layers(dim_r, fc_args)
def create_res_layers(self, shape, block_args=None):
'''Creates a set of residual blocks.
Args:
shape: input shape.
block_args: Arguments for blocks.
Returns:
nn.Sequential: sequence of residual blocks.
'''
res_layers = nn.Sequential()
block_args = block_args or []
for i, (conv_args, n_blocks) in enumerate(block_args):
block = ResBlock(shape, conv_args=conv_args)
res_layers.add_module('block_{}_0'.format(i), block)
for j in range(1, n_blocks):
shape = block.conv_shape
block = ResBlock(shape, conv_args=conv_args)
res_layers.add_module('block_{}_{}'.format(i, j), block)
shape = block.conv_shape
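        # Once the blocks reduce each crop to 1x1, forward() refolds the crops
        # into a final_size x final_size map, so report the refolded shape.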
dim_x, dim_y = shape[:2]
if dim_x != dim_y:
raise ValueError('dim_x and dim_y do not match.')
if dim_x == 1:
shape = (self.final_size, self.final_size, shape[2])
return res_layers, shape
def forward(self, x: torch.Tensor, return_full_list=False):
'''Forward pass
Args:
x: Input.
return_full_list: Optional, returns all layer outputs.
Returns:
            Tuple of outputs from the conv_before, residual, conv_after, and
            fully-connected stages (full per-layer lists if return_full_list
            is set, otherwise the final output of each stage).
'''
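        # Split the input into overlapping crops stacked along the batch axis;
        # each crop is encoded independently until refolded below.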
x = self.unfold(x)
conv_before_out = []
for conv_layer in self.conv_before_layers:
x = conv_layer(x)
if x.size(2) == 1:
x = self.refold(x)
conv_before_out.append(x)
res_out = []
for res_layer in self.res_layers:
x = res_layer(x)
res_out.append(x)
if x.size(2) == 1:
x = self.refold(x)
res_out[-1] = x
conv_after_out = []
for conv_layer in self.conv_after_layers:
x = conv_layer(x)
if x.size(2) == 1:
x = self.refold(x)
conv_after_out.append(x)
x = self.reshape(x)
if return_full_list:
fc_out = []
for fc_layer in self.fc_layers:
x = fc_layer(x)
fc_out.append(x)
else:
fc_out = self.fc_layers(x)
if not return_full_list:
conv_before_out = conv_before_out[-1]
res_out = res_out[-1]
conv_after_out = conv_after_out[-1]
return conv_before_out, res_out, conv_after_out, fc_out
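# A minimal usage sketch for FoldedResNet over 8x8 crops of a 32x32 input.
# Crops are encoded independently; once a stage reduces each crop to 1x1 it is
# refolded into a final_size x final_size (here 7x7) map. Layer-argument
# layouts are assumptions, as in the sketches above.
#
#     net = FoldedResNet((32, 32, 3), crop_size=8,
#                        conv_before_args=[(64, 3, 1, 1, False, True, 'ReLU', None)],
#                        res_args=[([(64, 3, 1, 1, False, True, 'ReLU', None),
#                                    (64, 3, 1, 1, False, True, 'ReLU', None)], 2)])
#     conv_out, res_out, conv_after_out, fc_out = net(torch.randn(8, 3, 32, 32))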