File size: 2,151 Bytes
28c256d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import mmengine
import numpy as np


def write_to_json(config: dict, filename: str):
    """Save ``config`` to ``filename`` in JSON format.

    The destination is opened in text write mode with UTF-8 encoding, so any
    existing file at that path is overwritten. The serialization itself is
    delegated to ``mmengine.dump``.

    Args:
        config (dict): Config to be saved.
        filename (str): Destination path of the JSON file.
    """
    with open(file=filename, mode='w', encoding='utf-8') as json_file:
        mmengine.dump(config, json_file, file_format='json')


def expand_rates(dilation: tuple, config: dict) -> list:
    """Expand a dilation rate into a list of candidate dilation rates.

    A "small" candidate is built by scaling ``dilation[0]`` and
    ``dilation[1]`` by ``1 - config['exp_rate']``. A "large" candidate
    scales them by ``1 + config['exp_rate']``. Each scaled component is
    rounded to an int with Python's built-in ``round`` (round-half-to-even)
    and then clipped to ``[config['mmin'], config['mmax']]``.

    Args:
        dilation (tuple): Current dilation rate. Only ``dilation[0]`` and
            ``dilation[1]`` are scaled. Lists/arrays are also accepted.
        config (dict): Search config. Keys read: ``'exp_rate'``,
            ``'num_branches'``, ``'mmin'`` and ``'mmax'``.

    Returns:
        list: Unique rate tuples, ordered as: small candidate, the original
        ``dilation`` (only when ``num_branches`` is odd), large candidate.
        Duplicates (e.g. created by clipping) are dropped, keeping the first
        occurrence. The small/large candidates appear only when
        ``num_branches // 2 > 0``: ``num_branches == 0`` gives ``[]`` and
        ``num_branches == 1`` gives ``[dilation]``.
    """
    exp_rate = config['exp_rate']
    num_branches = config['num_branches']
    mmin, mmax = config['mmin'], config['mmax']

    def _scaled(factor: float) -> tuple:
        # Scale, round and clip both spatial components; ``.item()`` turns
        # the numpy scalar returned by ``np.clip`` back into a Python number.
        return tuple(
            np.clip(int(round(factor * value)), mmin, mmax).item()
            for value in (dilation[0], dilation[1]))

    # Each of the ``num_branches // 2`` small (resp. large) candidates would
    # be the identical tuple, and duplicates are removed below, so computing
    # each candidate once gives the same result.
    has_side_branches = num_branches // 2 > 0

    rate_list = []
    if has_side_branches:
        rate_list.append(_scaled(1 - exp_rate))
    if num_branches % 2 != 0:
        # Normalize to a tuple so the centre rate is hashable even when a
        # list/array is passed (previously this raised ``TypeError``).
        centre = dilation if isinstance(dilation, tuple) else tuple(dilation)
        rate_list.append(centre)
    if has_side_branches:
        rate_list.append(_scaled(1 + exp_rate))

    # Order-preserving de-duplication (first occurrence wins).
    return list(dict.fromkeys(rate_list))


def get_single_padding(kernel_size: int,
                       stride: int = 1,
                       dilation: int = 1) -> int:
    """Compute the padding for one spatial dimension of a convolution.

    The result is ``((stride - 1) + dilation * (kernel_size - 1)) // 2``.
    With the default ``stride=1`` this is half of the dilated kernel extent
    ``dilation * (kernel_size - 1)``, rounded down.

    Args:
        kernel_size (int): Kernel size along this dimension.
        stride (int): Convolution stride. Defaults to 1.
        dilation (int): Dilation rate. Defaults to 1.

    Returns:
        int: Padding amount.
    """
    dilated_extent = dilation * (kernel_size - 1)
    return (stride - 1 + dilated_extent) // 2