|
|
|
import math |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
|
|
from backbones.classifier import FracClassifier |
|
|
|
|
|
class ContBatchNorm3d(nn.modules.batchnorm._BatchNorm):
    """Batch norm over 5-D inputs that always uses the current batch's statistics.

    Unlike ``nn.BatchNorm3d``, ``forward`` passes ``training=True``
    unconditionally, so the layer normalizes with the incoming batch's
    statistics (and keeps updating ``running_mean`` / ``running_var``)
    even after ``.eval()`` has been called.
    """

    def _check_input_dim(self, input):
        """Raise ``ValueError`` unless ``input`` has exactly five dimensions."""
        ndim = input.dim()
        if ndim == 5:
            return
        raise ValueError(f'expected 5D input (got {ndim}D input)')

    def forward(self, input):
        """Normalize ``input`` with batch statistics, regardless of train/eval mode."""
        self._check_input_dim(input)
        return F.batch_norm(
            input,
            self.running_mean,
            self.running_var,
            weight=self.weight,
            bias=self.bias,
            training=True,  # deliberate: never switch to running statistics
            momentum=self.momentum,
            eps=self.eps,
        )
|
|
|
|
|
class LUConv(nn.Module):
    """One ``Conv3d(kernel_size=3, padding=1) -> ContBatchNorm3d -> activation`` unit.

    Args:
        in_chan: number of input channels.
        out_chan: number of output channels (also the PReLU parameter count).
        act: activation name: ``'relu'``, ``'prelu'`` or ``'elu'``.

    Raises:
        ValueError: if ``act`` is not one of the supported names.
    """

    def __init__(self, in_chan, out_chan, act):
        super(LUConv, self).__init__()
        # Resolve the activation first so an unknown name fails before any
        # layer is allocated. It is assigned last so the registered submodule
        # order (and state_dict key order) stays conv1, bn1, activation.
        activation = self._build_activation(act, out_chan)
        self.conv1 = nn.Conv3d(in_chan, out_chan, kernel_size=3, padding=1)
        self.bn1 = ContBatchNorm3d(out_chan)
        self.activation = activation

    @staticmethod
    def _build_activation(act, out_chan):
        """Return the activation module named by ``act``."""
        if act == 'relu':
            # nn.ReLU takes no channel count; its only argument is ``inplace``.
            return nn.ReLU(inplace=True)
        if act == 'prelu':
            # One learnable negative slope per output channel.
            return nn.PReLU(out_chan)
        if act == 'elu':
            return nn.ELU(inplace=True)
        raise ValueError(
            f"unsupported activation {act!r}; expected 'relu', 'prelu' or 'elu'"
        )

    def forward(self, x):
        """Apply convolution, batch norm and the activation, in that order."""
        return self.activation(self.bn1(self.conv1(x)))
|
|
|
|
|
def _make_nConv(in_channel, depth, act, double_chnnel=False):
    """Stack two ``LUConv`` units for a U-Net stage at ``depth``.

    With ``base = 32 * 2**depth`` the channel progression is
    ``in_channel -> base -> 2*base`` by default, or
    ``in_channel -> 2*base -> 2*base`` when ``double_chnnel`` is truthy.
    (The misspelled keyword name is part of the call interface.)

    Returns:
        ``nn.Sequential`` holding the two ``LUConv`` layers.
    """
    base = 32 * 2 ** depth
    wide = 2 * base
    hidden = wide if double_chnnel else base
    first = LUConv(in_channel, hidden, act)
    second = LUConv(hidden, wide, act)
    return nn.Sequential(first, second)
|
|
|
class DownTransition(nn.Module):
    """Encoder stage: two ``LUConv`` layers followed by 2x max-pooling.

    The stage built with ``depth == 3`` skips the pooling step.

    Args:
        in_channel: number of input channels.
        depth: stage index, forwarded to ``_make_nConv``.
        act: activation name forwarded to every ``LUConv``.
    """

    def __init__(self, in_channel, depth, act):
        super(DownTransition, self).__init__()
        self.ops = _make_nConv(in_channel, depth, act)
        self.maxpool = nn.MaxPool3d(2)
        self.current_depth = depth

    def forward(self, x):
        """Return ``(output, features_before_pooling)``.

        At depth 3 both elements are the same tensor object.
        """
        features = self.ops(x)
        pooled = features if self.current_depth == 3 else self.maxpool(features)
        return pooled, features
|
|
|
class UpTransition(nn.Module):
    """Decoder stage: 2x transposed-conv upsampling, skip concatenation, two ``LUConv``.

    Args:
        inChans: channels of the tensor being upsampled.
        outChans: channels produced by the transposed convolution.
        depth: stage index, forwarded to ``_make_nConv``.
        act: activation name forwarded to every ``LUConv``.
    """

    def __init__(self, inChans, outChans, depth, act):
        super(UpTransition, self).__init__()
        self.depth = depth
        self.up_conv = nn.ConvTranspose3d(inChans, outChans, kernel_size=2, stride=2)
        # NOTE(review): ops expects inChans + outChans // 2 input channels, so
        # skip_x must carry (inChans + outChans // 2 - outChans) channels for
        # the concatenation below to fit -- confirm against callers.
        fused_channels = inChans + outChans // 2
        self.ops = _make_nConv(fused_channels, depth, act, double_chnnel=True)

    def forward(self, x, skip_x):
        """Upsample ``x``, concatenate ``skip_x`` after it on dim 1, then convolve."""
        upsampled = self.up_conv(x)
        merged = torch.cat([upsampled, skip_x], dim=1)
        return self.ops(merged)
|
|
|
|
|
class OutputTransition(nn.Module):
    """1x1x1 convolution to ``n_labels`` channels followed by a sigmoid.

    Args:
        inChans: number of input channels.
        n_labels: number of output channels.
    """

    def __init__(self, inChans, n_labels):
        super(OutputTransition, self).__init__()
        self.final_conv = nn.Conv3d(inChans, n_labels, kernel_size=1)

    def forward(self, x):
        """Return element-wise sigmoid of the 1x1x1 convolution of ``x``."""
        logits = self.final_conv(x)
        return logits.sigmoid()
|
|
|
class UNet3D(nn.Module):
    """3-D U-Net encoder (four ``DownTransition`` stages) with a ``FracClassifier`` head.

    Only the contracting path is built; its deepest 512-channel feature map
    is passed to the classifier.

    Args:
        input_size: spatial edge length of the (assumed cubic) input volume;
            used only to size the classifier's ``linear_kernel``.
        n_class: forwarded to ``FracClassifier`` as ``final_channels``.
        act: activation name forwarded to every ``LUConv``.
        in_channels: number of channels of the input volume.
    """

    def __init__(self, input_size, n_class=1, act='relu', in_channels=1):
        super(UNet3D, self).__init__()
        # Stage k emits 64 * 2**k channels; the depth-3 stage is not pooled.
        self.down_tr64 = DownTransition(in_channels, 0, act)
        self.down_tr128 = DownTransition(64, 1, act)
        self.down_tr256 = DownTransition(128, 2, act)
        self.down_tr512 = DownTransition(256, 3, act)

        # NOTE(review): the encoder pools only three times (8x reduction), yet
        # this size assumes 32x; presumably FracClassifier downsamples a
        # further 4x internally -- confirm against backbones.classifier.
        linear_kernel = int(math.pow(input_size / 32, 3) * 512)
        self.classifier = FracClassifier(
            encoder_channels=512,
            final_channels=n_class,
            linear_kernel=linear_kernel,
        )

    def forward(self, x):
        """Run the encoder stages and the classifier; return the classifier output.

        As a side effect, each stage's pooled output is stored on the module
        as ``out64`` / ``out128`` / ``out256`` / ``out512`` and the final
        result as ``out``.
        """
        stages = (
            ('out64', self.down_tr64),
            ('out128', self.down_tr128),
            ('out256', self.down_tr256),
            ('out512', self.down_tr512),
        )
        features = x
        for attr_name, stage in stages:
            features, _unpooled = stage(features)
            setattr(self, attr_name, features)
        self.out = self.classifier(features)
        return self.out