santit96's picture
Create the streamlit app that classifies the trash in an image into classes
fa84113
raw
history blame
2.13 kB
from .efficientdet import EfficientDet, HeadNet
from .bench import DetBenchTrain, DetBenchPredict
from .config import get_efficientdet_config
from .helpers import load_pretrained, load_checkpoint
def create_model(
model_name, bench_task='', num_classes=None, pretrained=False,
checkpoint_path='', checkpoint_ema=False, **kwargs):
config = get_efficientdet_config(model_name)
return create_model_from_config(
config, bench_task=bench_task, num_classes=num_classes, pretrained=pretrained,
checkpoint_path=checkpoint_path, checkpoint_ema=checkpoint_ema, **kwargs)
def create_model_from_config(
config, bench_task='', num_classes=None, pretrained=False,
checkpoint_path='', checkpoint_ema=False, **kwargs):
pretrained_backbone = kwargs.pop('pretrained_backbone', True)
if pretrained or checkpoint_path:
pretrained_backbone = False # no point in loading backbone weights
# Config overrides, override some config values via kwargs.
overrides = ('redundant_bias', 'label_smoothing', 'new_focal', 'jit_loss')
for ov in overrides:
value = kwargs.pop(ov, None)
if value is not None:
setattr(config, ov, value)
labeler = kwargs.pop('bench_labeler', False)
# create the base model
model = EfficientDet(config, pretrained_backbone=pretrained_backbone, **kwargs)
# pretrained weights are always spec'd for original config, load them before we change the model
if pretrained:
load_pretrained(model, config.url)
# reset model head if num_classes doesn't match configs
if num_classes is not None and num_classes != config.num_classes:
model.reset_head(num_classes=num_classes)
# load an argument specified training checkpoint
if checkpoint_path:
load_checkpoint(model, checkpoint_path, use_ema=checkpoint_ema)
# wrap model in task specific training/prediction bench if set
if bench_task == 'train':
model = DetBenchTrain(model, create_labeler=labeler)
elif bench_task == 'predict':
model = DetBenchPredict(model)
return model