Spaces:
Runtime error
Runtime error
'''Basic cortex_DIM encoder.
'''
import torch | |
from cortex_DIM.nn_modules.convnet import Convnet, FoldedConvnet | |
from cortex_DIM.nn_modules.resnet import ResNet, FoldedResNet | |
def create_encoder(Module):
    '''Builds an encoder class derived from the given network class.

    Args:
        Module: Network class to subclass. Its `forward` is called with
            `return_full_list=True` and must return either
            `(conv_outs, fc_outs)` or
            `(conv_before_outs, res_outs, conv_after_outs, fc_outs)`, where
            each element is an indexable sequence of per-layer outputs (the
            three convolutional sequences are joined with `+`).

    Returns:
        The `Encoder` class, a subclass of `Module`.

    '''

    class Encoder(Module):
        '''Encoder used for cortex_DIM.

        '''

        def __init__(self, *args, local_idx=None, multi_idx=None, conv_idx=None, fc_idx=None, **kwargs):
            '''

            Args:
                args: Arguments for parent class.
                local_idx: Index in list of convolutional layers for local features.
                multi_idx: Index in list of convolutional layers for multiple globals.
                conv_idx: Index in list of convolutional layers for intermediate features.
                    Defaults to `local_idx` when not given.
                fc_idx: Index in list of fully-connected layers for intermediate features.
                kwargs: Keyword arguments for the parent class.

            Raises:
                ValueError: If `local_idx` is not provided.

            '''
            super().__init__(*args, **kwargs)

            if local_idx is None:
                raise ValueError('`local_idx` must be set')

            # Test against None explicitly: 0 is a valid layer index, and a
            # truthiness test (`conv_idx or local_idx`) would discard it.
            if conv_idx is None:
                conv_idx = local_idx

            self.local_idx = local_idx
            self.multi_idx = multi_idx
            self.conv_idx = conv_idx
            self.fc_idx = fc_idx

        def forward(self, x: torch.Tensor):
            '''

            Args:
                x: Input tensor.

            Returns:
                local_out, conv_out, multi_out, hidden_out, global_out

                `multi_out` is None when `multi_idx` is unset, `hidden_out`
                is None when `fc_idx` is unset, and both `hidden_out` and
                `global_out` are None when the parent network produced no
                fully-connected outputs.

            '''
            outs = super().forward(x, return_full_list=True)

            if len(outs) == 2:
                conv_out, fc_out = outs
            else:
                # Four-part output: flatten the convolutional stages into one
                # sequence so the configured indices address all of them.
                conv_before_out, res_out, conv_after_out, fc_out = outs
                conv_out = conv_before_out + res_out + conv_after_out

            local_out = conv_out[self.local_idx]

            if self.multi_idx is not None:
                multi_out = conv_out[self.multi_idx]
            else:
                multi_out = None

            if len(fc_out) > 0:
                if self.fc_idx is not None:
                    hidden_out = fc_out[self.fc_idx]
                else:
                    hidden_out = None
                # The last fully-connected output is the global feature.
                global_out = fc_out[-1]
            else:
                hidden_out = None
                global_out = None

            # Rebind last: `conv_out` was the full list of layer outputs until
            # here, from now on it is the single selected layer output.
            conv_out = conv_out[self.conv_idx]

            return local_out, conv_out, multi_out, hidden_out, global_out

    return Encoder
class ConvnetEncoder(create_encoder(Convnet)):
    '''cortex_DIM encoder whose parent network class is `Convnet`.

    '''
class FoldedConvnetEncoder(create_encoder(FoldedConvnet)):
    '''cortex_DIM encoder whose parent network class is `FoldedConvnet`.

    '''
class ResnetEncoder(create_encoder(ResNet)):
    '''cortex_DIM encoder whose parent network class is `ResNet`.

    '''
class FoldedResnetEncoder(create_encoder(FoldedResNet)):
    '''cortex_DIM encoder whose parent network class is `FoldedResNet`.

    '''