# LuyangZ's picture
# Upload 30 files
# 01df1d6 verified
'''Basic cortex_DIM encoder.
'''
import torch
from cortex_DIM.nn_modules.convnet import Convnet, FoldedConvnet
from cortex_DIM.nn_modules.resnet import ResNet, FoldedResNet
def create_encoder(Module):
    '''Build an encoder class that subclasses ``Module``.

    The returned ``Encoder`` class calls ``Module.forward`` with
    ``return_full_list=True`` and selects intermediate outputs using the
    indices given at construction time.

    Args:
        Module: Parent network class. Its constructor receives the encoder's
            remaining ``*args``/``**kwargs``. Its ``forward`` must accept
            ``return_full_list=True`` and return either a 2-tuple
            ``(conv_outs, fc_outs)`` or a 4-tuple
            ``(conv_before_outs, res_outs, conv_after_outs, fc_outs)``.

    Returns:
        type: A subclass of ``Module`` named ``Encoder``.
    '''

    class Encoder(Module):
        '''Encoder used for cortex_DIM.
        '''

        def __init__(self, *args, local_idx=None, multi_idx=None, conv_idx=None, fc_idx=None, **kwargs):
            '''
            Args:
                args: Arguments for parent class.
                local_idx: Index in list of convolutional layers for local features.
                multi_idx: Index in list of convolutional layers for multiple globals.
                conv_idx: Index in list of convolutional layers for intermediate
                    features. Falls back to ``local_idx`` when ``None``.
                fc_idx: Index in list of fully-connected layers for intermediate features.
                kwargs: Keyword arguments for the parent class.

            Raises:
                ValueError: If ``local_idx`` is not set.
            '''
            super().__init__(*args, **kwargs)

            if local_idx is None:
                raise ValueError('`local_idx` must be set')

            # Compare against None explicitly: ``conv_idx or local_idx`` would
            # silently replace a legitimate index of 0 with ``local_idx``.
            if conv_idx is None:
                conv_idx = local_idx

            self.local_idx = local_idx
            self.multi_idx = multi_idx
            self.conv_idx = conv_idx
            self.fc_idx = fc_idx

        def forward(self, x: torch.Tensor):
            '''
            Args:
                x: Input tensor.

            Returns:
                tuple: ``(local_out, conv_out, multi_out, hidden_out, global_out)``.
                ``multi_out`` is None when ``multi_idx`` is unset. ``hidden_out``
                is None when ``fc_idx`` is unset or there are no fully-connected
                outputs. ``global_out`` is the last fully-connected output, or
                None when there are none.
            '''
            outs = super().forward(x, return_full_list=True)
            if len(outs) == 2:
                conv_out, fc_out = outs
            else:
                # 4-element form (presumably from the ResNet variants): flatten
                # the three convolutional groups so one index addresses them all.
                conv_before_out, res_out, conv_after_out, fc_out = outs
                conv_out = conv_before_out + res_out + conv_after_out

            local_out = conv_out[self.local_idx]
            multi_out = conv_out[self.multi_idx] if self.multi_idx is not None else None

            if len(fc_out) > 0:
                hidden_out = fc_out[self.fc_idx] if self.fc_idx is not None else None
                global_out = fc_out[-1]
            else:
                hidden_out = None
                global_out = None

            conv_out = conv_out[self.conv_idx]
            return local_out, conv_out, multi_out, hidden_out, global_out

    return Encoder
class ConvnetEncoder(create_encoder(Convnet)):
    '''cortex_DIM encoder whose parent network is ``Convnet``.

    Construct with ``local_idx`` (required) plus the ``Convnet`` arguments.
    '''
class FoldedConvnetEncoder(create_encoder(FoldedConvnet)):
    '''cortex_DIM encoder whose parent network is ``FoldedConvnet``.

    Construct with ``local_idx`` (required) plus the ``FoldedConvnet`` arguments.
    '''
class ResnetEncoder(create_encoder(ResNet)):
    '''cortex_DIM encoder whose parent network is ``ResNet``.

    Construct with ``local_idx`` (required) plus the ``ResNet`` arguments.
    '''
class FoldedResnetEncoder(create_encoder(FoldedResNet)):
    '''cortex_DIM encoder whose parent network is ``FoldedResNet``.

    Construct with ``local_idx`` (required) plus the ``FoldedResNet`` arguments.
    '''