Spaces:
Sleeping
Sleeping
File size: 2,127 Bytes
fa84113 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 |
from .efficientdet import EfficientDet, HeadNet
from .bench import DetBenchTrain, DetBenchPredict
from .config import get_efficientdet_config
from .helpers import load_pretrained, load_checkpoint
def create_model(
        model_name, bench_task='', num_classes=None, pretrained=False,
        checkpoint_path='', checkpoint_ema=False, **kwargs):
    """Build an EfficientDet model from a registered config name.

    The config for ``model_name`` is fetched with ``get_efficientdet_config``;
    every other argument is forwarded unchanged to
    ``create_model_from_config``, which does the actual construction.

    Args:
        model_name: name passed to ``get_efficientdet_config``.
        bench_task: forwarded to ``create_model_from_config``.
        num_classes: forwarded to ``create_model_from_config``.
        pretrained: forwarded to ``create_model_from_config``.
        checkpoint_path: forwarded to ``create_model_from_config``.
        checkpoint_ema: forwarded to ``create_model_from_config``.
        **kwargs: forwarded to ``create_model_from_config``.

    Returns:
        Whatever ``create_model_from_config`` returns.
    """
    return create_model_from_config(
        get_efficientdet_config(model_name),
        bench_task=bench_task,
        num_classes=num_classes,
        pretrained=pretrained,
        checkpoint_path=checkpoint_path,
        checkpoint_ema=checkpoint_ema,
        **kwargs)
def create_model_from_config(
        config, bench_task='', num_classes=None, pretrained=False,
        checkpoint_path='', checkpoint_ema=False, **kwargs):
    """Build an EfficientDet model from an already-resolved config object.

    Args:
        config: model config passed to ``EfficientDet``. NOTE: any override
            kwargs (see ``**kwargs``) are written onto this object in place
            via ``setattr``, so the caller's config is mutated.
        bench_task: ``'train'`` wraps the model in ``DetBenchTrain``,
            ``'predict'`` wraps it in ``DetBenchPredict``; ``''`` (default)
            or ``None`` returns the bare model.
        num_classes: if given and different from ``config.num_classes``, the
            head is reset with ``model.reset_head`` after any pretrained
            weights are loaded.
        pretrained: if true, load weights from ``config.url`` via
            ``load_pretrained``.
        checkpoint_path: if non-empty, load this checkpoint via
            ``load_checkpoint`` after any head reset.
        checkpoint_ema: forwarded to ``load_checkpoint`` as ``use_ema``.
        **kwargs: the following keys are consumed here:
            ``pretrained_backbone`` (default True; forced to False when
            ``pretrained`` or ``checkpoint_path`` is set);
            ``redundant_bias``, ``label_smoothing``, ``new_focal``,
            ``jit_loss`` (set on ``config`` when not ``None``);
            ``bench_labeler`` (default False, forwarded to ``DetBenchTrain``
            as ``create_labeler``). All remaining kwargs go to
            ``EfficientDet``.

    Returns:
        The model, wrapped in a train/predict bench if ``bench_task`` asks
        for one.

    Raises:
        ValueError: if ``bench_task`` is not ``''``, ``None``, ``'train'`` or
            ``'predict'``.
    """
    # Validate first so a typo fails fast, before any model construction or
    # weight loading, instead of silently returning an unwrapped model.
    if bench_task not in ('', None, 'train', 'predict'):
        raise ValueError(
            "Unknown bench_task {!r}; expected '', 'train' or 'predict'.".format(bench_task))

    pretrained_backbone = kwargs.pop('pretrained_backbone', True)
    if pretrained or checkpoint_path:
        pretrained_backbone = False  # no point in loading backbone weights

    # Config overrides, override some config values via kwargs.
    overrides = ('redundant_bias', 'label_smoothing', 'new_focal', 'jit_loss')
    for ov in overrides:
        value = kwargs.pop(ov, None)
        if value is not None:
            setattr(config, ov, value)

    # Popped before building the model so it is not forwarded to EfficientDet.
    labeler = kwargs.pop('bench_labeler', False)

    # create the base model
    model = EfficientDet(config, pretrained_backbone=pretrained_backbone, **kwargs)

    # pretrained weights are always spec'd for original config, load them before we change the model
    if pretrained:
        load_pretrained(model, config.url)

    # reset model head if num_classes doesn't match configs
    if num_classes is not None and num_classes != config.num_classes:
        model.reset_head(num_classes=num_classes)

    # load an argument specified training checkpoint
    if checkpoint_path:
        load_checkpoint(model, checkpoint_path, use_ema=checkpoint_ema)

    # wrap model in task specific training/prediction bench if set
    if bench_task == 'train':
        model = DetBenchTrain(model, create_labeler=labeler)
    elif bench_task == 'predict':
        model = DetBenchPredict(model)
    return model
|