File size: 1,555 Bytes
ab687e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from segmentation_models_pytorch.losses import TverskyLoss


# Registry mapping a config loss name (``config.LOSS.NAME``) to the loss
# class that ``get_loss_from_dict`` instantiates for it.
LOSSES = dict(
    tversky=TverskyLoss,
)


def get_loss_from_dict(loss_name, config):
    """Gets the proper loss given a loss name.

    Looks ``loss_name`` up in the module-level ``LOSSES`` registry and
    instantiates the registered class with hyper-parameters read from
    ``config.LOSS``.

    Args:
        loss_name (str): name of the loss; must be a key of ``LOSSES``.
        config: config object exposing a ``LOSS`` section. For ``'tversky'``
            the attributes MODE, CLASSES, LOG, LOGITS, SMOOTH, IGNORE_INDEX,
            EPS, ALPHA, BETA and GAMMA are read.

    Raises:
        KeyError: thrown if loss key is not present in dict
        NotImplementedError: if ``loss_name`` is registered in ``LOSSES``
            but this function has no config-to-constructor mapping for it.

    Returns:
        loss: pytorch loss
    """
    try:
        loss_to_use = LOSSES[loss_name]
    except KeyError:
        error_msg = f"{loss_name} is not an implemented loss"
        error_msg = f"{error_msg}. Available loss functions: {LOSSES.keys()}"
        # Drop the implicit context (a bare KeyError(loss_name)); this
        # message already names the missing key and the valid options.
        raise KeyError(error_msg) from None

    if loss_name == 'tversky':
        return loss_to_use(mode=config.LOSS.MODE,
                           classes=config.LOSS.CLASSES,
                           log_loss=config.LOSS.LOG,
                           from_logits=config.LOSS.LOGITS,
                           smooth=config.LOSS.SMOOTH,
                           ignore_index=config.LOSS.IGNORE_INDEX,
                           eps=config.LOSS.EPS,
                           alpha=config.LOSS.ALPHA,
                           beta=config.LOSS.BETA,
                           gamma=config.LOSS.GAMMA)

    # A loss registered in LOSSES without a branch above previously fell
    # through and silently returned None; fail loudly instead.
    raise NotImplementedError(
        f"{loss_name} is registered in LOSSES but get_loss_from_dict has "
        "no configuration mapping for it")


def build_loss(config):
    """
    Construct the loss function selected by ``config.LOSS.NAME``.

    Args:
        config: config object whose ``LOSS.NAME`` names the loss; the whole
            object is forwarded to ``get_loss_from_dict``.

    Returns:
        loss_to_use: pytorch loss function
    """
    # The name lookup and the construction are delegated to the registry
    # helper, which also reads the remaining LOSS.* hyper-parameters.
    return get_loss_from_dict(config.LOSS.NAME, config)