from io import BytesIO
from typing import Any, Dict, List, Optional, Union
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):
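    """Sentence-Transformers module that wraps an Mllama vision-language model
    and encodes interleaved text/image inputs into a single L2-normalised
    embedding via last-token pooling of the final hidden states."""
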
    def __init__(
            self,
            model_name_or_path: str,
            cache_dir: Optional[str] = None,
            tokenizer_args: Optional[Dict[str, Any]] = None,
            **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}

        # Initialize processor
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def _load_model(
            self,
            model_name_or_path: str,
            config,
            cache_dir: str,
            backend: str,
            is_peft_model: bool,
            **model_args,
    ) -> None:
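        # Load the Mllama backbone in bfloat16 to roughly halve its memory footprint.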
        self.auto_model = MllamaForConditionalGeneration.from_pretrained(
            model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
        )

    def forward(
            self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Process inputs through the model
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs
        )

        # Apply last pooling and normalization
        last_hidden_state = outputs.hidden_states[-1]
        attention_mask = features["attention_mask"]
        sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)

        features.update({"sentence_embedding": sentence_embedding})
        return features

    def _last_pooling(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Apply last token pooling and L2 normalization"""
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
        return torch.nn.functional.normalize(reps, p=2, dim=-1)

    def tokenize(self, texts: Union[List[List[Dict]], List[str]]) -> Dict[str, torch.Tensor]:
        """Convert a batch of plain strings and/or interleaved text-image items
        into processor features. Each non-string item is a list of dicts with a
        ``type`` of ``"text"``, ``"image_bytes"`` or ``"image_path"`` and a
        matching ``content`` field; every image is replaced by an ``<|image|>``
        placeholder token in the text stream."""
        def process_text_item(item):
            if isinstance(item, str):
                return item, []

            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] in ["image_bytes", "image_path"]:
                    text += "<|image|>"
                    if sub_item["type"] == "image_bytes":
                        img = Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    else:
                        img = Image.open(sub_item["content"]).convert("RGB")
                    images.append(img)
                else:
                    raise ValueError(f"Unknown data type {sub_item['type']}")
            return text, images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)

        # Process inputs through the processor; omit images entirely for a
        # text-only batch.
        inputs = self.processor(
            text=all_texts,
            images=all_images if all_images else None,
            padding="longest",
            truncation=True,
            max_length=self.max_seq_length,
            return_tensors="pt",
        )

        return inputs
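

# --- Minimal usage sketch, illustrative only ----------------------------------
# The checkpoint name and image path below are placeholders, not real assets;
# any Mllama-based embedding checkpoint that follows this module's input format
# should behave the same way when wrapped in a SentenceTransformer. Running a
# full-size Mllama model realistically requires a GPU.
if __name__ == "__main__":
    from sentence_transformers import SentenceTransformer

    embedder = MultiModalTransformer("path/to/mllama-embedding-checkpoint")  # placeholder path
    model = SentenceTransformer(modules=[embedder])

    # Plain strings embed text only; a list of typed parts mixes text and images,
    # with each image contributing an <|image|> placeholder token.
    inputs = [
        "a photo of a red bicycle leaning against a wall",
        [
            {"type": "image_path", "content": "bicycle.jpg"},  # placeholder image file
            {"type": "text", "content": "describe the main object in the picture"},
        ],
    ]
    embeddings = model.encode(inputs, convert_to_tensor=True)
    print(embeddings.shape)  # (2, hidden_size); rows are L2-normalised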