File size: 1,555 Bytes
ab687e7 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from segmentation_models_pytorch.losses import TverskyLoss
# Registry of supported losses: maps the loss name looked up by
# get_loss_from_dict to the class that implements it.
LOSSES = {"tversky": TverskyLoss}
def get_loss_from_dict(loss_name, config):
    """Instantiate the loss registered under ``loss_name``.

    Args:
        loss_name (str): key into the module-level ``LOSSES`` registry.
        config: config object. For ``'tversky'`` the attributes ``MODE``,
            ``CLASSES``, ``LOG``, ``LOGITS``, ``SMOOTH``, ``IGNORE_INDEX``,
            ``EPS``, ``ALPHA``, ``BETA`` and ``GAMMA`` of ``config.LOSS``
            are read and forwarded to the loss constructor.

    Raises:
        KeyError: if ``loss_name`` is not a key of ``LOSSES``.
        NotImplementedError: if ``loss_name`` is registered in ``LOSSES``
            but this function has no construction logic for it.

    Returns:
        loss: the instantiated loss object.
    """
    try:
        loss_to_use = LOSSES[loss_name]
    except KeyError:
        error_msg = f"{loss_name} is not an implemented loss"
        error_msg = f"{error_msg}. Available loss functions: {LOSSES.keys()}"
        # The bare lookup error adds nothing to the message above, so
        # suppress the implicit exception chain.
        raise KeyError(error_msg) from None

    if loss_name == 'tversky':
        loss_cfg = config.LOSS
        return loss_to_use(
            mode=loss_cfg.MODE,
            classes=loss_cfg.CLASSES,
            log_loss=loss_cfg.LOG,
            from_logits=loss_cfg.LOGITS,
            smooth=loss_cfg.SMOOTH,
            ignore_index=loss_cfg.IGNORE_INDEX,
            eps=loss_cfg.EPS,
            alpha=loss_cfg.ALPHA,
            beta=loss_cfg.BETA,
            gamma=loss_cfg.GAMMA,
        )

    # Reached only when a loss was added to LOSSES without a matching
    # construction branch above; fail explicitly instead of hitting an
    # unbound local variable at the return statement.
    raise NotImplementedError(
        f"{loss_name} is registered in LOSSES but has no construction logic"
    )
def build_loss(config):
    """Create the loss function selected by the configuration.

    The loss name is read from ``config.LOSS.NAME`` and passed, together
    with the config itself, to ``get_loss_from_dict``.

    Args:
        config: config object

    Returns:
        The loss object returned by ``get_loss_from_dict``.
    """
    return get_loss_from_dict(config.LOSS.NAME, config)
|