File size: 2,240 Bytes
fa84113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
from .transforms import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD


def _expand_per_channel(values, in_chans, name):
    """Return ``values`` as a tuple with one entry per input channel.

    A single value is repeated ``in_chans`` times; any other length must
    already equal ``in_chans``.

    Raises:
        AssertionError: if ``values`` has neither 1 nor ``in_chans`` entries.
    """
    values = tuple(values)
    if len(values) == 1:
        return values * in_chans
    assert len(values) == in_chans, (
        '%s must have 1 or %d values, got %d' % (name, in_chans, len(values)))
    return values


def resolve_input_config(args, model_config=None, model=None):
    """Resolve the input pre-processing configuration for a model.

    Every setting is looked up with the same precedence: an explicit value in
    ``args`` wins, then the value in the model config, then a built-in default.

    Args:
        args: a dict of overrides, or any object accepted by ``vars()`` (e.g.
            an ``argparse.Namespace``). Keys read: ``interpolation``, ``mean``,
            ``std`` and ``fill_color``. ``mean``/``std``/``fill_color`` are
            ignored when ``None``; ``interpolation`` is ignored when falsy.
        model_config: mapping that may hold ``input_size`` (3 values) or
            ``image_size`` (2 values, channel count is prepended), plus
            ``interpolation``, ``mean``, ``std`` and ``fill_color``.
        model: optional object whose ``config`` attribute is used when
            ``model_config`` is missing or empty.

    Returns:
        dict with keys ``input_size``, ``interpolation``, ``mean``, ``std``
        and ``fill_color``.

    Raises:
        AssertionError: if the resolved input size is not a 3-tuple, or if a
            ``mean``/``std`` override has neither 1 nor 3 entries.
    """
    if not isinstance(args, dict):
        args = vars(args)
    if not model_config and model is not None and hasattr(model, 'config'):
        model_config = model.config
    if model_config is None:
        # No config available from either source: fall back to the defaults
        # below instead of failing on ``'key' in None``.
        model_config = {}

    input_config = {}

    # resolve input/image size; the channel count is fixed at 3 here
    in_chans = 3
    if 'input_size' in model_config:
        input_size = tuple(model_config['input_size'])
    elif 'image_size' in model_config:
        input_size = (in_chans,) + tuple(model_config['image_size'])
    else:
        input_size = (in_chans, 512, 512)
    assert isinstance(input_size, tuple) and len(input_size) == 3, (
        'input size must resolve to 3 values, got %r' % (input_size,))
    input_config['input_size'] = input_size

    # resolve interpolation method (an empty/falsy arg counts as "not given")
    if args.get('interpolation'):
        input_config['interpolation'] = args['interpolation']
    elif 'interpolation' in model_config:
        input_config['interpolation'] = model_config['interpolation']
    else:
        input_config['interpolation'] = 'bicubic'

    # resolve dataset + model mean for normalization
    if args.get('mean') is not None:
        input_config['mean'] = _expand_per_channel(args['mean'], in_chans, 'mean')
    elif 'mean' in model_config:
        input_config['mean'] = model_config['mean']
    else:
        input_config['mean'] = IMAGENET_DEFAULT_MEAN

    # resolve dataset + model std deviation for normalization
    if args.get('std') is not None:
        input_config['std'] = _expand_per_channel(args['std'], in_chans, 'std')
    elif 'std' in model_config:
        input_config['std'] = model_config['std']
    else:
        input_config['std'] = IMAGENET_DEFAULT_STD

    # resolve letterbox fill color
    if args.get('fill_color') is not None:
        input_config['fill_color'] = args['fill_color']
    elif 'fill_color' in model_config:
        input_config['fill_color'] = model_config['fill_color']
    else:
        input_config['fill_color'] = 'mean'

    return input_config