Spaces:
Sleeping
Sleeping
| from .efficientdet import EfficientDet, HeadNet | |
| from .bench import DetBenchTrain, DetBenchPredict | |
| from .config import get_efficientdet_config | |
| from .helpers import load_pretrained, load_checkpoint | |
def create_model(
        model_name, bench_task='', num_classes=None, pretrained=False,
        checkpoint_path='', checkpoint_ema=False, **kwargs):
    """Create an EfficientDet model from a registered model name.

    Resolves ``model_name`` to a config with ``get_efficientdet_config`` and
    forwards that config, all named arguments, and any extra ``**kwargs``
    unchanged to ``create_model_from_config``. See that function for the
    meaning of each argument.

    Returns:
        Whatever ``create_model_from_config`` returns for the resolved config.
    """
    model_config = get_efficientdet_config(model_name)
    # Gather the named options explicitly; they can never collide with
    # ``kwargs`` because Python binds them to the parameters above.
    options = dict(
        bench_task=bench_task,
        num_classes=num_classes,
        pretrained=pretrained,
        checkpoint_path=checkpoint_path,
        checkpoint_ema=checkpoint_ema,
    )
    return create_model_from_config(model_config, **options, **kwargs)
def _apply_config_overrides(config, kwargs):
    """Move config-override entries out of ``kwargs`` and onto ``config``.

    Each recognised key is always removed from ``kwargs`` so it is never
    forwarded to the model constructor. Only values that are not ``None`` are
    written onto ``config`` (via ``setattr``), which modifies the passed config
    object in place.
    """
    for key in ('redundant_bias', 'label_smoothing', 'new_focal', 'jit_loss'):
        override = kwargs.pop(key, None)
        if override is None:
            continue
        setattr(config, key, override)


def create_model_from_config(
        config, bench_task='', num_classes=None, pretrained=False,
        checkpoint_path='', checkpoint_ema=False, **kwargs):
    """Build an EfficientDet from ``config``, optionally load weights and wrap it.

    Args:
        config: Model config. It is mutated in place by any override kwargs,
            and its ``url`` and ``num_classes`` attributes are read.
        bench_task: ``'train'`` wraps the model in ``DetBenchTrain``.
            ``'predict'`` wraps it in ``DetBenchPredict``. Any other value
            returns the bare ``EfficientDet``.
        num_classes: When not ``None`` and different from
            ``config.num_classes``, the model head is reset to this size.
        pretrained: If truthy, weights are loaded from ``config.url`` before
            any head reset.
        checkpoint_path: If truthy, a checkpoint is loaded from this path after
            the head reset.
        checkpoint_ema: Passed to ``load_checkpoint`` as ``use_ema``.
        **kwargs: Special keys, consumed here and not forwarded:

            - ``pretrained_backbone``: default ``True``; forced to ``False``
              when ``pretrained`` or ``checkpoint_path`` is truthy.
            - ``redundant_bias``, ``label_smoothing``, ``new_focal``,
              ``jit_loss``: set on ``config`` when not ``None``.
            - ``bench_labeler``: default ``False``; passed to ``DetBenchTrain``
              as ``create_labeler``.

            Everything else is forwarded to the ``EfficientDet`` constructor.

    Returns:
        The ``EfficientDet`` model, or the bench object wrapping it.
    """
    # Always consume the flag so it never reaches EfficientDet through **kwargs.
    backbone_weights = kwargs.pop('pretrained_backbone', True)
    if pretrained or checkpoint_path:
        # Full-model weights are loaded below, so backbone-only weights add nothing.
        backbone_weights = False

    _apply_config_overrides(config, kwargs)
    create_labeler = kwargs.pop('bench_labeler', False)

    net = EfficientDet(config, pretrained_backbone=backbone_weights, **kwargs)

    # Published weights match the original config, so load them before any head change.
    if pretrained:
        load_pretrained(net, config.url)

    if num_classes is not None and num_classes != config.num_classes:
        net.reset_head(num_classes=num_classes)

    # Loaded after the head reset (presumably so the checkpoint's head matches
    # num_classes -- verify against how checkpoints are produced).
    if checkpoint_path:
        load_checkpoint(net, checkpoint_path, use_ema=checkpoint_ema)

    if bench_task == 'train':
        return DetBenchTrain(net, create_labeler=create_labeler)
    if bench_task == 'predict':
        return DetBenchPredict(net)
    return net