File size: 1,851 Bytes
			
			| 3ef1661 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 | from .losses import *
from mono.utils.comm import get_func
import os
def build_from_cfg(cfg, default_args=None):
    """Build a loss module from a config dict.

    The ``type`` entry names a class that is resolved as
    ``<this package>.losses.<type>`` via ``get_func``; every remaining entry
    is forwarded to that class's constructor as a keyword argument.

    Args:
        cfg (dict): Config dict. It must contain the key "type".
        default_args (dict, optional): Fallback constructor arguments, used
            only for keys that ``cfg`` does not already set.

    Returns:
        object: The constructed object.

    Raises:
        TypeError: If ``cfg`` is not a dict.
        RuntimeError: If ``cfg`` has no "type" key.
        KeyError: If ``get_func`` returns None for the requested name.
    """
    if not isinstance(cfg, dict):
        raise TypeError(f'cfg must be a dict, but got {type(cfg)}')
    if 'type' not in cfg:
        raise RuntimeError('should contain the loss name')
    # Work on a copy so popping "type" does not mutate the caller's config.
    args = cfg.copy()
    if default_args is not None:
        # Explicit cfg entries take precedence over the defaults.
        for name, value in default_args.items():
            args.setdefault(name, value)

    obj_name = args.pop('type')
    # Prefer the real package name of this module. The fallback (the original
    # heuristic) only works when cwd is the repository root and paths use '/'.
    package = __package__ or os.path.dirname(__file__).split(os.getcwd() + '/')[-1].replace('/', '.')
    obj_path = f'{package}.losses.{obj_name}'

    # Resolve first, then instantiate: calling a None lookup result would raise
    # an unhelpful TypeError instead of the KeyError below.
    obj_cls = get_func(obj_path)
    if obj_cls is None:
        raise KeyError(f'cannot find {obj_name}.')
    return obj_cls(**args)
        
        
def build_criterions(cfg):
    """Build every configured loss, grouped by the keys of ``cfg.losses``.

    ``cfg.losses`` maps a group name to a list of loss config dicts. Each loss
    config is merged with ``cfg.data_basic`` before being passed to
    ``build_from_cfg``; on key collisions the ``data_basic`` value wins. Loss
    configs that declare ``out_channel`` have it overwritten with
    ``cfg.out_channel``. The caller's ``cfg`` is not modified.

    Args:
        cfg: Config object supporting both ``in`` membership tests and
            attribute access (``cfg.losses``, ``cfg.data_basic``, and
            ``cfg.out_channel`` when a loss declares ``out_channel``).

    Returns:
        dict: Maps each group name in ``cfg.losses`` to the list of built
        loss objects, in config order.

    Raises:
        RuntimeError: If ``cfg`` has no "losses" entry or it is not a dict.
    """
    if 'losses' not in cfg:
        raise RuntimeError('Losses have not been configured.')
    cfg_data_basic = cfg.data_basic
    criterions = dict()
    losses = cfg.losses
    if not isinstance(losses, dict):
        raise RuntimeError(f'Cannot initial losses with the type {type(losses)}')
    for key, loss_list in losses.items():
        criterions[key] = []
        for loss_cfg_i in loss_list:
            # Merge into a copy: updating loss_cfg_i in place would permanently
            # inject the data_basic keys into the caller's cfg.losses entries.
            loss_cfg_i = loss_cfg_i.copy()
            # update the canonical_space configs to the current loss cfg
            loss_cfg_i.update(cfg_data_basic)
            if 'out_channel' in loss_cfg_i:
                # classification losses need the model's channel count
                loss_cfg_i.update(out_channel=cfg.out_channel)
            obj_cls = build_from_cfg(loss_cfg_i)
            criterions[key].append(obj_cls)
    return criterions
            
            
        
            
            
            
            
            
            
        
    
  
 | 
