Spaces:
Sleeping
Sleeping
from .efficientdet import EfficientDet, HeadNet | |
from .bench import DetBenchTrain, DetBenchPredict | |
from .config import get_efficientdet_config | |
from .helpers import load_pretrained, load_checkpoint | |
def create_model(
        model_name, bench_task='', num_classes=None, pretrained=False,
        checkpoint_path='', checkpoint_ema=False, **kwargs):
    """Create a model from a named EfficientDet configuration.

    The configuration is looked up with ``get_efficientdet_config(model_name)``
    and then handed, together with every other argument unchanged, to
    ``create_model_from_config``, which does the actual construction.

    Args:
        model_name: Name passed to ``get_efficientdet_config``.
        bench_task: Forwarded as-is; ``'train'`` / ``'predict'`` select a bench
            wrapper in ``create_model_from_config``.
        num_classes: Forwarded as-is; optional class-count override.
        pretrained: Forwarded as-is; whether to load pretrained weights.
        checkpoint_path: Forwarded as-is; optional checkpoint to load.
        checkpoint_ema: Forwarded as-is; passed on as the checkpoint loader's
            ``use_ema`` flag.
        **kwargs: Forwarded as-is to ``create_model_from_config``.

    Returns:
        Whatever ``create_model_from_config`` returns.
    """
    model_config = get_efficientdet_config(model_name)
    return create_model_from_config(
        model_config,
        bench_task=bench_task,
        num_classes=num_classes,
        pretrained=pretrained,
        checkpoint_path=checkpoint_path,
        checkpoint_ema=checkpoint_ema,
        **kwargs)
def create_model_from_config(
        config, bench_task='', num_classes=None, pretrained=False,
        checkpoint_path='', checkpoint_ema=False, **kwargs):
    """Build an ``EfficientDet`` model from a config object.

    Steps, in order: apply config overrides taken from ``kwargs``, construct
    the model, optionally load pretrained weights, optionally reset the head
    for a different class count, optionally load a checkpoint, and optionally
    wrap the model in a train/predict bench.

    Note that ``config`` is modified in place (via ``setattr``) for any
    override supplied through ``kwargs``.

    Args:
        config: Model configuration; must expose ``num_classes`` and, when
            ``pretrained`` is set, ``url``.
        bench_task: ``'train'`` wraps the model in ``DetBenchTrain``,
            ``'predict'`` wraps it in ``DetBenchPredict``; any other value
            returns the bare model.
        num_classes: If not ``None`` and different from ``config.num_classes``,
            the model head is reset to this many classes.
        pretrained: If truthy, load weights from ``config.url``.
        checkpoint_path: If truthy, load this checkpoint into the model.
        checkpoint_ema: Passed to ``load_checkpoint`` as ``use_ema``.
        **kwargs: Recognised and consumed here: ``pretrained_backbone``
            (default ``True``), ``bench_labeler`` (default ``False``), and the
            config overrides ``redundant_bias``, ``label_smoothing``,
            ``new_focal``, ``jit_loss``. Everything left over is passed to the
            ``EfficientDet`` constructor.

    Returns:
        The constructed model, possibly wrapped in a bench.
    """
    # Backbone weights would be overwritten by full-model pretrained weights or
    # by a checkpoint, so there is no point in loading them in those cases.
    requested_backbone = kwargs.pop('pretrained_backbone', True)
    load_backbone = False if (pretrained or checkpoint_path) else requested_backbone

    # Selected config fields can be overridden through kwargs. A value of None
    # is treated the same as "not supplied" and leaves the config untouched.
    for field_name in ('redundant_bias', 'label_smoothing', 'new_focal', 'jit_loss'):
        override = kwargs.pop(field_name, None)
        if override is None:
            continue
        setattr(config, field_name, override)

    # Must be removed from kwargs before they reach the EfficientDet constructor.
    create_labeler = kwargs.pop('bench_labeler', False)

    # Base model.
    net = EfficientDet(config, pretrained_backbone=load_backbone, **kwargs)

    # Pretrained weights are specified for the original config, so they are
    # loaded before the head is (possibly) changed below.
    if pretrained:
        load_pretrained(net, config.url)

    # Swap in a fresh head when the requested class count differs from the config.
    if num_classes is not None and num_classes != config.num_classes:
        net.reset_head(num_classes=num_classes)

    # A checkpoint given by the caller is loaded last, after any head reset.
    if checkpoint_path:
        load_checkpoint(net, checkpoint_path, use_ema=checkpoint_ema)

    # Optionally wrap the model in a task-specific training/prediction bench.
    if bench_task == 'train':
        return DetBenchTrain(net, create_labeler=create_labeler)
    if bench_task == 'predict':
        return DetBenchPredict(net)
    return net