import os
import json
from typing import Dict, List, Optional

import torch
from torch import nn
import torch.nn.functional as F
from transformers import AutoModel

os.environ["TOKENIZERS_PARALLELISM"] = "false"
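

# NOTE: This wrapper mirrors the custom-module interface used by
# sentence-transformers (tokenize / forward producing 'sentence_embedding',
# get_sentence_embedding_dimension, get_max_seq_length, save / load),
# presumably so it can be dropped into a SentenceTransformer-style pipeline.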
class KananaEmbeddingWrapper(nn.Module):
    def __init__(
        self,
        model_name_or_path: str,
        trust_remote_code: bool = True,
        device: str = "cpu",
        max_seq_length: Optional[int] = None,
    ):
        """
        Initialize the KananaEmbeddingWrapper.

        Args:
            model_name_or_path: Path or name of the pretrained model
            trust_remote_code: Whether to trust remote code when loading the model
            device: Device to load the model on (e.g., 'cpu', 'cuda')
            max_seq_length: Maximum sequence length; defaults to the model's
                max_position_embeddings when not provided
        """
        super().__init__()
        self.model_name_or_path = model_name_or_path
        self.trust_remote_code = trust_remote_code
        self.device = device
        self.kanana2vec = AutoModel.from_pretrained(
            model_name_or_path, trust_remote_code=trust_remote_code
        ).to(self.device)
        self.max_seq_length = (
            max_seq_length
            if max_seq_length is not None
            else self.kanana2vec.config.max_position_embeddings
        )

    def get_sentence_embedding_dimension(self) -> int:
        """
        Returns the dimension of the sentence embeddings.

        Returns:
            Dimensionality of the sentence embeddings
        """
        return self.kanana2vec.config.hidden_size

    def get_max_seq_length(self) -> int:
        """
        Returns the maximum sequence length this module can process.

        Returns:
            Maximum sequence length
        """
        return self.max_seq_length

    def tokenize(self, texts: List[str]) -> Dict[str, torch.Tensor]:
        """
        Tokenize input texts.

        Args:
            texts: List of input texts to tokenize

        Returns:
            Dictionary containing tokenized inputs
        """
        return self.kanana2vec.tokenizer(
            texts,
            padding=True,
            return_token_type_ids=False,
            return_tensors="pt",
            truncation=True,
            max_length=self.max_seq_length,
        ).to(self.device)

    def forward(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Forward pass of the module.

        Args:
            features: Dictionary with inputs including 'input_ids', 'attention_mask', etc.

        Returns:
            Dictionary with updated features including 'sentence_embedding'
        """
        # Extract only the required features for the model
        model_inputs = self._extract_model_inputs(features)
        # Create pool mask considering prompt length if available
        model_inputs["pool_mask"] = self._create_pool_mask(features)
        # Get embeddings from the model and normalize
        embedding = self.kanana2vec.forward(**model_inputs).embedding
        normalized_embedding = F.normalize(embedding, p=2, dim=1)
        # Update features with sentence embedding
        features["sentence_embedding"] = normalized_embedding
        return features

    def _extract_model_inputs(self, features: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        """
        Extract only the inputs needed for the model.

        Args:
            features: Complete feature dictionary

        Returns:
            Dictionary with only the required keys for the model
        """
        return {k: v for k, v in features.items() if k in ["input_ids", "attention_mask"]}

    def _create_pool_mask(self, features: Dict[str, torch.Tensor]) -> torch.Tensor:
        """
        Create a pool mask based on attention mask and prompt length.

        Args:
            features: Feature dictionary containing attention_mask and optionally prompt_length

        Returns:
            Pool mask tensor
        """
        pool_mask = features["attention_mask"].clone()
        if "prompt_length" in features:
            # Exclude the leading prompt/instruction tokens from pooling
            pool_mask[:, : features["prompt_length"]] = 0
        return pool_mask

    def get_config_dict(self) -> Dict:
        """
        Returns a dictionary with the module's configuration.

        Returns:
            Dictionary with module configuration
        """
        return {
            "model_name_or_path": self.model_name_or_path,
            "trust_remote_code": self.trust_remote_code,
            "device": self.device,
            "hidden_size": self.get_sentence_embedding_dimension(),
            "max_seq_length": self.get_max_seq_length(),
        }

    def save(self, save_dir: str) -> None:
        """
        Saves the module's configuration and model to the specified directory.

        Args:
            save_dir: Directory to save the module configuration
        """
        os.makedirs(save_dir, exist_ok=True)
        # Save model configuration
        config_path = os.path.join(save_dir, "kanana_embedding_config.json")
        with open(config_path, "w", encoding="utf-8") as f:
            json.dump(self.get_config_dict(), f, ensure_ascii=False, indent=2)
        # Save the underlying model
        model_save_path = os.path.join(save_dir, "kanana2vec")
        self.kanana2vec.save_pretrained(model_save_path)
        print(f"KananaEmbeddingWrapper model saved to {save_dir}")

    @staticmethod
    def load(load_dir: str, device: str = "cpu") -> "KananaEmbeddingWrapper":
        """
        Loads a KananaEmbeddingWrapper model from the specified directory.

        Args:
            load_dir: Directory containing the saved module
            device: Device to load the model on

        Returns:
            Initialized KananaEmbeddingWrapper
        """
        # Load configuration
        config_path = os.path.join(load_dir, "kanana_embedding_config.json")
        with open(config_path, "r", encoding="utf-8") as f:
            config = json.load(f)
        # Use the saved model path
        model_load_path = os.path.join(load_dir, "kanana2vec")
        # Create instance with the saved configuration
        instance = KananaEmbeddingWrapper(
            model_name_or_path=model_load_path,
            trust_remote_code=config.get("trust_remote_code", True),
            device=device,  # Use the provided device or default
            max_seq_length=config.get("max_seq_length"),  # Restore the saved sequence limit
        )
        return instance
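

# --- Usage sketch (illustrative addition, not part of the original module) ---
# A minimal example of how the wrapper might be exercised end to end: load a
# checkpoint, tokenize a couple of sentences, run the forward pass, and compare
# the L2-normalized embeddings with a dot product. The checkpoint name below is
# an assumption; substitute the actual Kanana embedding model id or local path.
if __name__ == "__main__":
    wrapper = KananaEmbeddingWrapper(
        model_name_or_path="kakaocorp/kanana-nano-2.1b-embedding",  # assumed checkpoint name
        device="cuda" if torch.cuda.is_available() else "cpu",
    )

    features = wrapper.tokenize(["The first example sentence.", "A second, different sentence."])
    with torch.no_grad():
        outputs = wrapper(features)

    embeddings = outputs["sentence_embedding"]  # shape: (2, hidden_size), L2-normalized
    similarity = embeddings @ embeddings.T      # cosine similarity, since rows are unit-norm
    print(embeddings.shape, similarity)

    # Round-trip the save/load helpers defined above.
    wrapper.save("./kanana_wrapper_ckpt")
    restored = KananaEmbeddingWrapper.load("./kanana_wrapper_ckpt", device=wrapper.device)
    print(restored.get_sentence_embedding_dimension(), restored.get_max_seq_length())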