import os
import json
import torch
from torch import nn
import torch.nn.functional as F
from typing import List, Dict, Optional

from transformers import AutoModel

os.environ["TOKENIZERS_PARALLELISM"] = "false"

class KananaEmbeddingWrapper(nn.Module):
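    """
    Thin nn.Module wrapper around a Hugging Face Kanana embedding model
    ("kanana2vec"), exposing a sentence-transformers-style module interface:
    tokenize(), a forward() that adds a 'sentence_embedding' key, accessors for
    the embedding dimension and maximum sequence length, and save()/load() helpers.
    """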
    def __init__(self, model_name_or_path: str, trust_remote_code=True, device: str = "cpu", max_seq_length: Optional[int] = None):
        """
        Initialize the KananaEmbeddingWrapper.
        
        Args:
            model_name_or_path: Path or name of the pretrained model
            trust_remote_code: Whether to trust remote code when loading the model
            device: Device to load the model on (e.g., 'cpu', 'cuda')
            max_seq_length: Maximum sequence length used for tokenization; defaults to
                the model's max_position_embeddings when not provided
        """
        super(KananaEmbeddingWrapper, self).__init__()
        self.model_name_or_path = model_name_or_path
        self.trust_remote_code = trust_remote_code
        self.device = device
        self.kanana2vec = AutoModel.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code
        ).to(self.device)

        self.max_seq_length = max_seq_length if max_seq_length is not None else self.kanana2vec.config.max_position_embeddings

    def get_sentence_embedding_dimension(self) -> int:
        """
        Returns the dimension of the sentence embeddings.
        
        Returns:
            Dimensionality of the sentence embeddings
        """
        return self.kanana2vec.config.hidden_size
    
    def get_max_seq_length(self) -> int:
        """
        Returns the maximum sequence length this module can process.
        
        Returns:
            Maximum sequence length
        """
        return self.max_seq_length
    
    def tokenize(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """
        Tokenize input texts.
        
        Args:
            texts: List of input texts to tokenize
            
        Returns:
            Dictionary containing tokenized inputs
        """
        return self.kanana2vec.tokenizer(
            texts,
            padding=True,
            return_token_type_ids=False,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_seq_length
        ).to(self.device)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the module.
        
        Args:
            features: Dictionary with inputs including 'input_ids', 'attention_mask', etc.
        
        Returns:
            Dictionary with updated features including 'sentence_embedding'
        """
        # Extract only the required features for the model
        model_inputs = self._extract_model_inputs(features)
        
        # Create pool mask considering prompt length if available
        model_inputs["pool_mask"] = self._create_pool_mask(features)
        
        # Get embeddings from the model and normalize
        embedding = self.kanana2vec(**model_inputs).embedding
        normalized_embedding = F.normalize(embedding, p=2, dim=1)
        
        # Update features with sentence embedding
        features['sentence_embedding'] = normalized_embedding
        return features
    
    def _extract_model_inputs(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Extract only the inputs needed for the model.
        
        Args:
            features: Complete feature dictionary
            
        Returns:
            Dictionary with only the required keys for the model
        """
        return {k: v for k, v in features.items() if k in ['input_ids', 'attention_mask']}
    
    def _create_pool_mask(self, features: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Create a pool mask based on attention mask and prompt length.
        
        Args:
            features: Feature dictionary containing attention_mask and optionally prompt_length
            
        Returns:
            Pool mask tensor
        """
        pool_mask = features['attention_mask'].clone()
        if "prompt_length" in features:
            pool_mask[:, :features['prompt_length']] = 0
        return pool_mask
    
    def get_config_dict(self) -> Dict:
        """
        Returns a dictionary with the module's configuration.
        
        Returns:
            Dictionary with module configuration
        """
        return {
            "model_name_or_path": self.model_name_or_path,
            "trust_remote_code": self.trust_remote_code,
            "device": self.device,
            "hidden_size": self.get_sentence_embedding_dimension(),
            "max_seq_length": self.get_max_seq_length()
        }
    
    def save(self, save_dir: str) -> None:
        """
        Saves the module's configuration and model to the specified directory.
        
        Args:
            save_dir: Directory to save the module configuration
        """
        os.makedirs(save_dir, exist_ok=True)
        
        # Save model configuration
        config_path = os.path.join(save_dir, "kanana_embedding_config.json")
        with open(config_path, 'w', encoding='utf-8') as f:
            json.dump(self.get_config_dict(), f, ensure_ascii=False, indent=2)
        
        # Save the underlying model
        model_save_path = os.path.join(save_dir, "kanana2vec")
        self.kanana2vec.save_pretrained(model_save_path)
        
        print(f"KananaEmbeddingWrapper model saved to {save_dir}")
    
    @staticmethod
    def load(load_dir: str, device: str = "cpu") -> 'KananaEmbeddingWrapper':
        """
        Loads a KananaEmbeddingWrapper model from the specified directory.
        
        Args:
            load_dir: Directory containing the saved module
            device: Device to load the model on
            
        Returns:
            Initialized KananaEmbeddingWrapper
        """
        # Load configuration
        config_path = os.path.join(load_dir, "kanana_embedding_config.json")
        with open(config_path, 'r', encoding='utf-8') as f:
            config = json.load(f)
        
        # Use the saved model path
        model_load_path = os.path.join(load_dir, "kanana2vec")
        
        # Create instance with the saved configuration, restoring max_seq_length
        instance = KananaEmbeddingWrapper(
            model_name_or_path=model_load_path,
            trust_remote_code=config.get("trust_remote_code", True),
            device=device,  # Use the provided device or default
            max_seq_length=config.get("max_seq_length")
        )
        
        return instance
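

# ---------------------------------------------------------------------------
# Example usage: a minimal sketch. The checkpoint name below is a placeholder
# assumption, not something this module defines; substitute the Kanana
# embedding checkpoint you actually use.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    # Hypothetical checkpoint name for illustration only.
    model_name = "kakaocorp/kanana-nano-2.1b-embedding"

    wrapper = KananaEmbeddingWrapper(model_name, device="cpu", max_seq_length=512)

    # Tokenize a small batch and compute L2-normalized sentence embeddings.
    features = wrapper.tokenize([
        "What is the capital of Korea?",
        "Seoul is the capital of Korea.",
    ])
    with torch.no_grad():
        embeddings = wrapper(features)["sentence_embedding"]
    print(embeddings.shape)  # (batch_size, hidden_size)

    # Persist the configuration plus the underlying model, then reload.
    wrapper.save("./kanana_wrapper_ckpt")
    reloaded = KananaEmbeddingWrapper.load("./kanana_wrapper_ckpt", device="cpu")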