Spaces:
Runtime error
Runtime error
File size: 2,663 Bytes
01df1d6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 |
'''Basic cortex_DIM encoder.
'''
import torch
from cortex_DIM.nn_modules.convnet import Convnet, FoldedConvnet
from cortex_DIM.nn_modules.resnet import ResNet, FoldedResNet
def create_encoder(Module):
    '''Build a cortex_DIM encoder class on top of a network base class.

    Args:
        Module: Base network class. Its ``forward`` must accept
            ``return_full_list=True`` and then return either a 2-sequence
            ``(conv_outs, fc_outs)`` or a 4-sequence
            ``(conv_before_outs, res_outs, conv_after_outs, fc_outs)``,
            where each element is an indexable list of per-layer outputs.

    Returns:
        A subclass of ``Module`` named ``Encoder``.
    '''

    class Encoder(Module):
        '''Encoder used for cortex_DIM.

        Wraps ``Module.forward`` so that it also returns selected
        intermediate layer outputs, chosen by the indices given at
        construction time.
        '''

        def __init__(self, *args, local_idx=None, multi_idx=None, conv_idx=None, fc_idx=None, **kwargs):
            '''
            Args:
                args: Positional arguments for the parent class.
                local_idx: Index in list of convolutional layers for local
                    features. Required.
                multi_idx: Index in list of convolutional layers for multiple
                    globals. Optional.
                conv_idx: Index in list of convolutional layers for
                    intermediate features. Defaults to ``local_idx`` when not
                    given (``None``).
                fc_idx: Index in list of fully-connected layers for
                    intermediate features. Optional.
                kwargs: Keyword arguments for the parent class.

            Raises:
                ValueError: If ``local_idx`` is not set.
            '''
            super().__init__(*args, **kwargs)

            if local_idx is None:
                raise ValueError('`local_idx` must be set')

            # Compare against None explicitly: 0 is a valid layer index and
            # must not be silently replaced by ``local_idx``.
            if conv_idx is None:
                conv_idx = local_idx

            self.local_idx = local_idx
            self.multi_idx = multi_idx
            self.conv_idx = conv_idx
            self.fc_idx = fc_idx

        def forward(self, x: torch.Tensor):
            '''
            Args:
                x: Input tensor.

            Returns:
                Tuple ``(local_out, conv_out, multi_out, hidden_out, global_out)``:
                    local_out: convolutional output at ``local_idx``.
                    conv_out: convolutional output at ``conv_idx``.
                    multi_out: convolutional output at ``multi_idx``, or None
                        if ``multi_idx`` is unset.
                    hidden_out: fully-connected output at ``fc_idx``, or None
                        if ``fc_idx`` is unset or there are no
                        fully-connected outputs.
                    global_out: last fully-connected output, or None if there
                        are none.
            '''
            outs = super().forward(x, return_full_list=True)
            if len(outs) == 2:
                conv_out, fc_out = outs
            else:
                # Four-part output (presumably from the ResNet variants):
                # join the convolutional stages into one list so that the
                # layer indices span all of them in order.
                conv_before_out, res_out, conv_after_out, fc_out = outs
                conv_out = conv_before_out + res_out + conv_after_out

            local_out = conv_out[self.local_idx]
            multi_out = conv_out[self.multi_idx] if self.multi_idx is not None else None

            if len(fc_out) > 0:
                hidden_out = fc_out[self.fc_idx] if self.fc_idx is not None else None
                global_out = fc_out[-1]
            else:
                hidden_out = None
                global_out = None

            conv_out = conv_out[self.conv_idx]
            return local_out, conv_out, multi_out, hidden_out, global_out

    return Encoder
class ConvnetEncoder(create_encoder(Convnet)):
    '''cortex_DIM encoder built on ``Convnet``.'''


class FoldedConvnetEncoder(create_encoder(FoldedConvnet)):
    '''cortex_DIM encoder built on ``FoldedConvnet``.'''


class ResnetEncoder(create_encoder(ResNet)):
    '''cortex_DIM encoder built on ``ResNet``.'''


class FoldedResnetEncoder(create_encoder(FoldedResNet)):
    '''cortex_DIM encoder built on ``FoldedResNet``.'''
|