from typing import Tuple, Type

from torch import nn


class Base(object):
    """Common base class and name-based lookup for the ResNet backbones.

    Subclasses are expected to override :meth:`features`; the base
    implementation only raises ``NotImplementedError``.
    """

    # Names accepted by `from_name`.
    OPTIONS = ['resnet18', 'resnet50', 'resnet101']

    @staticmethod
    def from_name(name: str) -> Type['Base']:
        """Return the backbone class (not an instance) registered under *name*.

        The concrete backbone modules are imported inside each branch, so only
        the module of the requested backbone is loaded.
        NOTE(review): the deferred imports presumably also avoid a circular
        import with the ``backbone.*`` modules -- confirm before hoisting them.

        Args:
            name: One of the entries in ``Base.OPTIONS``.

        Returns:
            The class object implementing the requested backbone.

        Raises:
            ValueError: If *name* is not a known backbone name.
        """
        if name == 'resnet18':
            from backbone.resnet18 import ResNet18
            return ResNet18
        if name == 'resnet50':
            from backbone.resnet50 import ResNet50
            return ResNet50
        if name == 'resnet101':
            from backbone.resnet101 import ResNet101
            return ResNet101
        # Report the rejected value and the accepted ones so that a typo in a
        # config or CLI flag can be diagnosed from the traceback alone.
        raise ValueError(
            f'Unknown backbone name {name!r}; expected one of {Base.OPTIONS}'
        )

    def __init__(self, pretrained: bool):
        """Store the ``pretrained`` flag for use by subclasses.

        Args:
            pretrained: Flag kept on the instance as ``self._pretrained``.
                How it is consumed is up to the subclass (not visible here).
        """
        super().__init__()
        self._pretrained = pretrained

    def features(self) -> Tuple[nn.Module, nn.Module, int, int]:
        """Build the backbone's feature components; must be overridden.

        Returns:
            A 4-tuple of two ``nn.Module`` objects and two ints. The meaning
            of each element is defined by the overriding subclasses and is not
            specified in this base class.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError