from io import BytesIO
from typing import Any, Dict, Optional, List
import torch
from PIL import Image
from transformers import AutoProcessor, MllamaForConditionalGeneration
from sentence_transformers.models import Transformer as BaseTransformer


class MultiModalTransformer(BaseTransformer):
    def __init__(
            self,
            model_name_or_path: str,
            cache_dir: Optional[str] = None,
            tokenizer_args: Optional[Dict[str, Any]] = None,
            **kwargs,
    ):
        super().__init__(model_name_or_path, **kwargs)
        if tokenizer_args is None:
            tokenizer_args = {}

        # Initialize processor
        self.processor = AutoProcessor.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, **tokenizer_args
        )

    def _load_model(
            self,
            model_name_or_path: str,
            config,
            cache_dir: str,
            backend: str,
            is_peft_model: bool,
            **model_args,
    ) -> None:
        # Load the Mllama vision-language backbone in bfloat16 to keep memory usage manageable
        self.auto_model = MllamaForConditionalGeneration.from_pretrained(
            model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
        )

    def forward(
            self, features: Dict[str, torch.Tensor], **kwargs
    ) -> Dict[str, torch.Tensor]:
        # Process inputs through the model
        outputs = self.auto_model(
            **features,
            return_dict=True,
            output_hidden_states=True,
            **kwargs
        )

        # Apply last pooling and normalization
        last_hidden_state = outputs.hidden_states[-1]
        attention_mask = features["attention_mask"]
        sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)

        features.update({"sentence_embedding": sentence_embedding})
        return features

    def _last_pooling(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
        """Apply last token pooling and L2 normalization"""
        sequence_lengths = attention_mask.sum(dim=1) - 1
        batch_size = last_hidden_state.shape[0]
        reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
        return torch.nn.functional.normalize(reps, p=2, dim=-1)

    def tokenize(self, texts: List[List[Dict]] | List[str]) -> Dict[str, torch.Tensor]:
        """Convert raw inputs into model features.

        Each item is either a plain string or a list of dicts with a "type" key
        ("text", "image_bytes", or "image_path") and a "content" key; every image
        is replaced by an <|image|> token in the text passed to the processor.
        """
        def process_text_item(item):
            if isinstance(item, str):
                return item, []

            text, images = "", []
            for sub_item in item:
                if sub_item["type"] == "text":
                    text += sub_item["content"]
                elif sub_item["type"] in ["image_bytes", "image_path"]:
                    text += "<|image|>"
                    if sub_item["type"] == "image_bytes":
                        img = Image.open(BytesIO(sub_item["content"])).convert("RGB")
                    else:
                        img = Image.open(sub_item["content"]).convert("RGB")
                    images.append(img)
                else:
                    raise ValueError(f"Unknown data type {sub_item['type']}")
            return text, images

        all_texts, all_images = [], []
        for item in texts:
            text, images = process_text_item(item)
            all_texts.append(text)
            all_images.extend(images)

        # Process inputs through the processor
        if all_images:
            inputs = self.processor(
                text=all_texts,
                images=all_images,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )
        else:
            inputs = self.processor(
                text=all_texts,
                padding="longest",
                truncation=True,
                max_length=self.max_seq_length,
                return_tensors="pt"
            )

        return inputs
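

# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of how this module could be exercised end to end. The
# checkpoint name and image path below are placeholders, not real assets,
# and in practice the loaded model should be moved to a GPU before forward.
if __name__ == "__main__":
    module = MultiModalTransformer("your-org/your-mllama-embedding-model")

    # One multimodal item: an instruction plus an image referenced by path.
    batch = module.tokenize([
        [
            {"type": "text", "content": "Represent the given image."},
            {"type": "image_path", "content": "example.jpg"},
        ],
    ])

    with torch.no_grad():
        features = module.forward(batch)

    # (batch_size, hidden_size); embeddings are L2-normalized by _last_pooling
    print(features["sentence_embedding"].shape)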