sadimanna's picture
Upload 20 files
d6def08
raw
history blame
782 Bytes
from typing import Tuple, Type
from torch import nn
class Base(object):
    """Abstract backbone interface plus a name-based factory for concrete backbones.

    Concrete backbones are subclasses that override :meth:`features`; the base
    implementation only raises ``NotImplementedError``.
    """

    # Backbone names accepted by ``from_name``; keep in sync with its branches.
    OPTIONS = ['resnet18', 'resnet50', 'resnet101']

    @staticmethod
    def from_name(name: str) -> Type['Base']:
        """Return the backbone class registered under ``name``.

        Each concrete class is imported inside its own branch, so only the
        module for the requested backbone is imported.

        Args:
            name: One of ``Base.OPTIONS``.

        Returns:
            The backbone class itself (not an instance).

        Raises:
            ValueError: If ``name`` is not one of ``Base.OPTIONS``.
        """
        if name == 'resnet18':
            from backbone.resnet18 import ResNet18
            return ResNet18
        if name == 'resnet50':
            from backbone.resnet50 import ResNet50
            return ResNet50
        if name == 'resnet101':
            from backbone.resnet101 import ResNet101
            return ResNet101
        # Name the bad value and the valid choices so misconfiguration is easy to diagnose.
        raise ValueError(
            'Unknown backbone {!r}; expected one of {}'.format(name, Base.OPTIONS)
        )

    def __init__(self, pretrained: bool):
        """Store whether the backbone should use pretrained weights.

        Args:
            pretrained: Flag kept on ``self._pretrained`` for subclasses to read.
        """
        super().__init__()
        self._pretrained = pretrained

    def features(self) -> Tuple[nn.Module, nn.Module, int, int]:
        """Return the backbone's feature modules and two integer sizes.

        Must be overridden by subclasses. The meaning of each tuple element is
        defined by the subclasses, which are not visible here. NOTE(review):
        presumably a feature extractor, a second (hidden/head) module and two
        channel counts; confirm against the subclass implementations.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError