File size: 782 Bytes
d6def08 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
from typing import Tuple, Type
from torch import nn
class Base(object):
    """Base class for backbones that can be selected by name.

    Subclasses are expected to override `features`; the implementation here
    only raises ``NotImplementedError``.
    """

    # Names accepted by `from_name`.
    OPTIONS = ['resnet18', 'resnet50', 'resnet101']

    @staticmethod
    def from_name(name: str) -> Type['Base']:
        """Return the backbone class (not an instance) registered as ``name``.

        Args:
            name: One of the strings listed in ``Base.OPTIONS``.

        Returns:
            The class object matching ``name``.

        Raises:
            ValueError: If ``name`` is not a recognised backbone name.
        """
        # Each import is kept inside its branch so that only the requested
        # backbone module is loaded (presumably this also avoids a circular
        # import with the backbone modules -- confirm against those files).
        if name == 'resnet18':
            from backbone.resnet18 import ResNet18
            return ResNet18
        elif name == 'resnet50':
            from backbone.resnet50 import ResNet50
            return ResNet50
        elif name == 'resnet101':
            from backbone.resnet101 import ResNet101
            return ResNet101
        else:
            # Report the offending value and the valid choices instead of
            # raising an empty ValueError.
            raise ValueError(
                'Unknown backbone name {!r}; expected one of {}'.format(
                    name, Base.OPTIONS
                )
            )

    def __init__(self, pretrained: bool):
        """Store the ``pretrained`` flag on the instance.

        Args:
            pretrained: Saved as ``self._pretrained``; this class does not
                use it further (presumably subclasses read it when building
                their network -- confirm in the subclasses).
        """
        super().__init__()
        self._pretrained = pretrained

    def features(self) -> Tuple[nn.Module, nn.Module, int, int]:
        """Return the backbone's components; must be overridden.

        Returns:
            A tuple of two ``nn.Module`` objects and two ints, per the
            annotation. The meaning of each element is defined by the
            subclasses, not by this class.

        Raises:
            NotImplementedError: Always, in this base implementation.
        """
        raise NotImplementedError
|