File size: 3,987 Bytes
d4b7753
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
from io import BytesIO
from typing import Any, Dict, Optional, List

import torch
from PIL import Image
from sentence_transformers.models import Transformer as BaseTransformer
from transformers import AutoProcessor, MllamaForConditionalGeneration


class MultiModalTransformer(BaseTransformer):
    """Sentence-transformers module that embeds text and/or images.

    Inputs are converted to prompts by an ``AutoProcessor`` (image parts become
    an ``<|image|>`` prompt prefix), run through ``self.auto_model``, and pooled
    by taking the last-layer hidden state of the last non-padding token,
    followed by L2 normalization.
    """

    def __init__(
        self,
        model_name_or_path: str,
        cache_dir: Optional[str] = None,
        tokenizer_args: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """Load the base model (via ``BaseTransformer``) and a matching processor.

        Args:
            model_name_or_path: Hub id or local path of the model.
            cache_dir: Download cache directory. It is used for both the base
                model/tokenizer and the processor.
            tokenizer_args: Extra keyword arguments for
                ``AutoProcessor.from_pretrained``.
            **kwargs: Forwarded unchanged to ``BaseTransformer.__init__``.
        """
        # Forward cache_dir so the model weights go to the same cache as the
        # processor files (it used to reach only AutoProcessor).
        super().__init__(model_name_or_path, cache_dir=cache_dir, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}

        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

        config = self.auto_model.config
        # forward() only reads hidden states from a single call, so caching
        # past key/values is unnecessary.
        if hasattr(config, "use_cache"):
            config.use_cache = False

        # _last_pooling finds the final real token via attention_mask.sum(),
        # which is only correct when padding is appended on the right.
        padding_side = "right"
        self.processor.tokenizer.padding_side = padding_side
        config.padding_side = padding_side
        self.auto_model.padding_side = padding_side

    def forward(
        self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        """Run the model and attach a pooled, L2-normalized sentence embedding.

        Args:
            features: Model inputs as produced by ``tokenize``. They must
                contain ``"attention_mask"``.
            **kwargs: Extra keyword arguments passed to ``self.auto_model``.

        Returns:
            ``features``, mutated in place to include ``"sentence_embedding"``.
        """
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs,
        )
        last_hidden_state = outputs.hidden_states[-1]
        features["sentence_embedding"] = self._last_pooling(
            last_hidden_state, features["attention_mask"]
        )
        return features

    def _last_pooling(
        self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor
    ) -> torch.Tensor:
        """Select each row's last non-padding token and L2-normalize it.

        Indexing assumes ``last_hidden_state`` is (batch, seq, ...) and
        ``attention_mask`` is (batch, seq), right-padded. Right padding is
        configured in ``__init__``.
        """
        # With right padding, (number of real tokens - 1) is the index of the
        # last real token in each row.
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[
            torch.arange(batch_size, device=last_hidden_state.device),
            sequence_lengths,
        ]
        return torch.nn.functional.normalize(reps, p=2, dim=-1)

    @staticmethod
    def _parse_item(
        item: str | List[Dict[str, Any]],
    ) -> tuple[str, List[Image.Image]]:
        """Split one input into its prompt text and the RGB images it references.

        Raises:
            ValueError: If a part has an unknown ``"type"``.
        """
        if isinstance(item, str):
            return item, []

        text_parts: List[str] = []
        images: List[Image.Image] = []
        for part in item:
            kind = part["type"]
            if kind == "text":
                text_parts.append(part["content"])
            elif kind in ("image_bytes", "image_path"):
                text_parts.append(
                    "<|image|><|begin_of_text|> Represent the given image"
                )
                source = (
                    BytesIO(part["content"])
                    if kind == "image_bytes"
                    else part["content"]
                )
                # Close the underlying file handle deterministically. convert()
                # returns a new, fully loaded image, so it outlives the block.
                with Image.open(source) as img:
                    images.append(img.convert("RGB"))
            else:
                raise ValueError(f"Unknown data type {part['type']}")
        return "".join(text_parts), images

    def tokenize(
        self, texts: List[str | List[Dict[str, Any]]]
    ) -> Dict[str, torch.Tensor]:
        """Convert raw inputs into processor outputs for ``forward``.

        Each element of ``texts`` is either a plain string or a list of parts.
        Each part is a dict whose ``"type"`` is ``"text"``, ``"image_bytes"``
        or ``"image_path"``, with the matching ``"content"`` (a string, raw
        bytes, or a path).

        Returns:
            The processor output (``return_tensors="pt"``), padded to the
            longest sample and truncated to ``self.max_seq_length``.

        Raises:
            ValueError: If a part has an unknown ``"type"``.
        """
        all_texts: List[str] = []
        # Keep one image list per sample. The processor's batch format pairs
        # each text with its own image list; a flat list would make a
        # multi-sample batch look like one sample holding every image.
        all_images: List[List[Image.Image]] = []
        for item in texts:
            text, images = self._parse_item(item)
            all_texts.append(text)
            all_images.append(images)

        processor_kwargs: Dict[str, Any] = {
            "text": all_texts,
            "padding": "longest",
            "truncation": True,
            "max_length": self.max_seq_length,
            "return_tensors": "pt",
        }
        # NOTE(review): processors such as Mllama's may reject batches that
        # mix image and text-only samples. Confirm how callers batch inputs.
        if any(all_images):
            processor_kwargs["images"] = all_images
        return self.processor(**processor_kwargs)