intfloat Samoed committed (verified)
Commit 4c104a7 · 1 Parent(s): 078d6bc

Create custom_st.py (#2)


- Upload README.md with huggingface_hub (d553fa4a69fefa90695ced98a56771ed8c23c647)
- Create custom_st.py (d4b7753e4b68e46a1b04dbcae8965e24cc57054f)
- Update custom_st.py (31932a6c194ac72e9188576ba1bc19132c13711d)
- use only `<|image|>` (cbde83efa53606c8ab4326f31ea9ea4c737f39a9)
- return `<|begin_of_text|>` (46f332ef39fac617a5736ffabc0733914b60b440)
- remove `<|begin_of_text|>` (519b6b5cdf39e4437db4eea2c67f263892872bfe)
- Update README.md (b620de43ac130d2c246b250e3d4a576f276a211e)
- Update custom_st.py (7f44eb9c6998bd23545dcdba73c2518e4e4d69f2)


Co-authored-by: Solomatin Roman <[email protected]>

Files changed (1)
  1. custom_st.py +110 -0
custom_st.py ADDED
@@ -0,0 +1,110 @@
+ from io import BytesIO
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from PIL import Image
+ from sentence_transformers.models import Transformer as BaseTransformer
+ from transformers import AutoProcessor, MllamaForConditionalGeneration
+
+
+ class MultiModalTransformer(BaseTransformer):
+     def __init__(
+         self,
+         model_name_or_path: str,
+         cache_dir: Optional[str] = None,
+         tokenizer_args: Optional[Dict[str, Any]] = None,
+         **kwargs,
+     ):
+         super().__init__(model_name_or_path, **kwargs)
+         if tokenizer_args is None:
+             tokenizer_args = {}
+
+         # Initialize processor (tokenizer + image processor)
+         self.processor = AutoProcessor.from_pretrained(
+             model_name_or_path, cache_dir=cache_dir, **tokenizer_args
+         )
+
+     def _load_model(
+         self,
+         model_name_or_path: str,
+         config,
+         cache_dir: str,
+         backend: str,
+         is_peft_model: bool,
+         **model_args,
+     ) -> None:
+         self.auto_model = MllamaForConditionalGeneration.from_pretrained(
+             model_name_or_path, torch_dtype=torch.bfloat16, cache_dir=cache_dir, **model_args
+         )
+
+     def forward(
+         self, features: Dict[str, torch.Tensor], **kwargs
+     ) -> Dict[str, torch.Tensor]:
+         # Process inputs through the model
+         outputs = self.auto_model(
+             **features,
+             return_dict=True,
+             output_hidden_states=True,
+             **kwargs,
+         )
+
+         # Apply last-token pooling and normalization
+         last_hidden_state = outputs.hidden_states[-1]
+         attention_mask = features["attention_mask"]
+         sentence_embedding = self._last_pooling(last_hidden_state, attention_mask)
+
+         features.update({"sentence_embedding": sentence_embedding})
+         return features
+
+     def _last_pooling(self, last_hidden_state: torch.Tensor, attention_mask: torch.Tensor) -> torch.Tensor:
+         """Apply last-token pooling and L2 normalization."""
+         sequence_lengths = attention_mask.sum(dim=1) - 1
+         batch_size = last_hidden_state.shape[0]
+         reps = last_hidden_state[torch.arange(batch_size, device=last_hidden_state.device), sequence_lengths]
+         return torch.nn.functional.normalize(reps, p=2, dim=-1)
+
+     def tokenize(self, texts: List[List[Dict]] | List[str]) -> Dict[str, torch.Tensor]:
+         def process_text_item(item):
+             if isinstance(item, str):
+                 return item, []
+
+             text, images = "", []
+             for sub_item in item:
+                 if sub_item["type"] == "text":
+                     text += sub_item["content"]
+                 elif sub_item["type"] in ["image_bytes", "image_path"]:
+                     text += "<|image|>"
+                     if sub_item["type"] == "image_bytes":
+                         img = Image.open(BytesIO(sub_item["content"])).convert("RGB")
+                     else:
+                         img = Image.open(sub_item["content"]).convert("RGB")
+                     images.append(img)
+                 else:
+                     raise ValueError(f"Unknown data type {sub_item['type']}")
+             return text, images
+
+         all_texts, all_images = [], []
+         for item in texts:
+             text, images = process_text_item(item)
+             all_texts.append(text)
+             all_images.extend(images)
+
+         # Process inputs through the processor
+         if all_images:
+             inputs = self.processor(
+                 text=all_texts,
+                 images=all_images,
+                 padding="longest",
+                 truncation=True,
+                 max_length=self.max_seq_length,
+                 return_tensors="pt",
+             )
+         else:
+             inputs = self.processor(
+                 text=all_texts,
+                 padding="longest",
+                 truncation=True,
+                 max_length=self.max_seq_length,
+                 return_tensors="pt",
+             )
+
+         return inputs
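
For context, a minimal usage sketch, assuming this repository registers MultiModalTransformer in its modules.json so that sentence-transformers loads custom_st.py when the model is created with trust_remote_code=True. The repo id below is a placeholder, and the input format mirrors the tokenize method above:

from sentence_transformers import SentenceTransformer

# Placeholder repo id -- substitute the repository that actually ships this custom_st.py.
model = SentenceTransformer("<org>/<model-repo>", trust_remote_code=True)

# Plain strings take the isinstance(item, str) branch of tokenize().
text_emb = model.encode(["A cat sleeping on a laptop keyboard."])

# Mixed inputs are lists of {"type", "content"} dicts; "image_path" reads from
# disk, "image_bytes" from raw bytes, and each image adds one <|image|> tag.
mixed_emb = model.encode([
    [
        {"type": "image_path", "content": "cat.jpg"},
        {"type": "text", "content": "Represent the given image."},
    ]
])

# _last_pooling L2-normalizes the embeddings, so cosine similarity is a dot product.
print(text_emb @ mixed_emb.T)

Last-token pooling fits the decoder-only backbone here: with causal attention and right padding, the final non-padding position is the only one that has attended to the entire sequence.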