santit96's picture
Create the streamlit app that classifies the trash in an image into classes
fa84113
raw
history blame
2.24 kB
from .transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD
def resolve_input_config(args, model_config=None, model=None):
if not isinstance(args, dict):
args = vars(args)
input_config = {}
if not model_config and model is not None and hasattr(model, 'config'):
model_config = model.config
# Resolve input/image size
in_chans = 3
input_size = (in_chans, 512, 512)
if 'input_size' in model_config:
input_size = tuple(model_config['input_size'])
elif 'image_size' in model_config:
input_size = (in_chans,) + tuple(model_config['image_size'])
assert isinstance(input_size, tuple) and len(input_size) == 3
input_config['input_size'] = input_size
# resolve interpolation method
input_config['interpolation'] = 'bicubic'
if 'interpolation' in args and args['interpolation']:
input_config['interpolation'] = args['interpolation']
elif 'interpolation' in model_config:
input_config['interpolation'] = model_config['interpolation']
# resolve dataset + model mean for normalization
input_config['mean'] = IMAGENET_DEFAULT_MEAN
if 'mean' in args and args['mean'] is not None:
mean = tuple(args['mean'])
if len(mean) == 1:
mean = tuple(list(mean) * in_chans)
else:
assert len(mean) == in_chans
input_config['mean'] = mean
elif 'mean' in model_config:
input_config['mean'] = model_config['mean']
# resolve dataset + model std deviation for normalization
input_config['std'] = IMAGENET_DEFAULT_STD
if 'std' in args and args['std'] is not None:
std = tuple(args['std'])
if len(std) == 1:
std = tuple(list(std) * in_chans)
else:
assert len(std) == in_chans
input_config['std'] = std
elif 'std' in model_config:
input_config['std'] = model_config['std']
# resolve letterbox fill color
input_config['fill_color'] = 'mean'
if 'fill_color' in args and args['fill_color'] is not None:
input_config['fill_color'] = args['fill_color']
elif 'fill_color' in model_config:
input_config['fill_color'] = model_config['fill_color']
return input_config