liyy201912's picture
Upload folder using huggingface_hub
cc0dd3c
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional
import torch
from mmengine.model import BaseModule
from mmpretrain.registry import MODELS
@MODELS.register_module()
class MIMHead(BaseModule):
    """Pre-training head for Masked Image Modeling.

    The head holds no parameters of its own. It wraps a single loss module,
    built from a registry config, and delegates the loss computation to it.

    Args:
        loss (dict): Config dict used to build the loss module through the
            ``MODELS`` registry.
    """

    def __init__(self, loss: dict) -> None:
        super().__init__()
        # Build the criterion first, then attach it. Assigning it as an
        # attribute registers it as a submodule of this head.
        criterion = MODELS.build(loss)
        self.loss_module = criterion

    def loss(self,
             pred: torch.Tensor,
             target: torch.Tensor,
             mask: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Compute the masked-image-modeling loss.

        All three inputs are passed positionally, in order, to the wrapped
        loss module. ``mask`` is forwarded even when it is ``None``.

        Args:
            pred (torch.Tensor): Predictions with shape B x L x C.
            target (torch.Tensor): Targets with shape B x L x C.
            mask (torch.Tensor, optional): Mask with shape B x L.
                Defaults to None.

        Returns:
            torch.Tensor: Whatever the wrapped loss module returns.
        """
        return self.loss_module(pred, target, mask)