File size: 459 Bytes
590af54
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch.nn as nn
import pyrootutils


# Import-time side effect: locate the project root (the directory marked by a
# '.project-root' indicator file, searched for starting from this file) and,
# because pythonpath=True, put it on the Python path so project-local imports
# resolve when this module is imported directly.
# NOTE(review): setup_root's other defaults (e.g. setting a PROJECT_ROOT env var,
# loading a .env file, changing the working directory) presumably also apply
# here since they are not overridden — confirm these side effects are intended
# at import time.
pyrootutils.setup_root(__file__, indicator='.project-root', pythonpath=True)

class DiscreteModleIdentity(nn.Module):
    """No-op "discrete model" that leaves image embeddings untouched.

    Wraps a single ``nn.Identity`` sub-module (registered as ``model``), so the
    module has no trainable parameters. ``encode_image_embeds`` is the only
    method that does any work; ``forward`` accepts the same arguments other
    discrete models might take but produces nothing.
    """

    def __init__(self) -> None:
        super().__init__()
        # Registered as a child module so it appears in state_dict/children,
        # even though nn.Identity itself holds no parameters.
        self.model = nn.Identity()

    def forward(self, image_embeds, input_ids=None, text_attention_mask=None, text_embeds=None):
        """Accept image/text inputs and return ``None``.

        All arguments are ignored.
        NOTE(review): this forward is a stub. Any caller expecting a loss or
        output from ``model(...)`` will get ``None``. Confirm that only
        ``encode_image_embeds`` is used for this identity variant.
        """
        return None

    def encode_image_embeds(self, image_embeds):
        """Return ``image_embeds`` passed through the identity sub-module (i.e. unchanged)."""
        passthrough = self.model
        return passthrough(image_embeds)