from segmentation_models_pytorch.losses import TverskyLoss

# Registry mapping the loss name used in the configuration to the loss class.
LOSSES = {
    'tversky': TverskyLoss,
}


def get_loss_from_dict(loss_name, config):
    """Instantiate the loss registered under ``loss_name``.

    Args:
        loss_name (str): key of the loss in ``LOSSES``.
        config: config object; for ``'tversky'`` the attributes MODE,
            CLASSES, LOG, LOGITS, SMOOTH, IGNORE_INDEX, EPS, ALPHA, BETA
            and GAMMA are read from ``config.LOSS``.

    Raises:
        KeyError: if ``loss_name`` is not a key of ``LOSSES``.
        NotImplementedError: if ``loss_name`` is registered in ``LOSSES``
            but this function does not know how to construct it.

    Returns:
        loss: the instantiated loss object.
    """
    try:
        loss_to_use = LOSSES[loss_name]
    except KeyError as err:
        error_msg = (
            f"{loss_name} is not an implemented loss. "
            f"Available loss functions: {LOSSES.keys()}"
        )
        # Chain explicitly so the traceback shows the failed lookup as cause.
        raise KeyError(error_msg) from err

    if loss_name == 'tversky':
        loss_cfg = config.LOSS
        return loss_to_use(
            mode=loss_cfg.MODE,
            classes=loss_cfg.CLASSES,
            log_loss=loss_cfg.LOG,
            from_logits=loss_cfg.LOGITS,
            smooth=loss_cfg.SMOOTH,
            ignore_index=loss_cfg.IGNORE_INDEX,
            eps=loss_cfg.EPS,
            alpha=loss_cfg.ALPHA,
            beta=loss_cfg.BETA,
            gamma=loss_cfg.GAMMA,
        )

    # A name can be present in LOSSES without a construction branch above;
    # fail with a clear message instead of returning an unbound local.
    raise NotImplementedError(
        f"{loss_name} is registered in LOSSES but get_loss_from_dict "
        "does not define how to construct it"
    )


def build_loss(config):
    """Build the loss function described by a configuration object.

    Args:
        config: config object; the loss name is read from
            ``config.LOSS.NAME`` and the object is passed on to
            ``get_loss_from_dict``.

    Returns:
        loss_to_use: the loss returned by ``get_loss_from_dict``.
    """
    loss_name = config.LOSS.NAME
    loss_to_use = get_loss_from_dict(loss_name, config)

    return loss_to_use