"""Expose PyTorch's cross-entropy loss through the project's ``register`` helper."""

from torch import nn

from src.core import register

# Pass the stock ``torch.nn.CrossEntropyLoss`` class through ``register`` and
# publish whatever it returns under the same name.
# NOTE(review): ``register`` presumably records the class in a project-wide
# component registry and returns it unchanged; confirm against src.core.
CrossEntropyLoss = register(nn.CrossEntropyLoss)