import logging
import os
from collections import defaultdict
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from huggingface_hub import snapshot_download
from torch import nn
from transformers import AutoModel, AutoTokenizer, is_torch_npu_available

logger = logging.getLogger(__name__)

class GTEEmbedding(nn.Module):
    def __init__(self,
                 model_name: str = None,
                 normalized: bool = True,
                 pooling_method: str = 'cls',
                 use_fp16: bool = True,
                 device: str = None
                 ):
        super().__init__()
        self.load_model(model_name)
        self.vocab_size = self.model.config.vocab_size
        self.normalized = normalized
        self.pooling_method = pooling_method
        if device:
            self.device = torch.device(device)
        else:
            if torch.cuda.is_available():
                self.device = torch.device("cuda")
            elif torch.backends.mps.is_available():
                self.device = torch.device("mps")
            elif is_torch_npu_available():
                self.device = torch.device("npu")
            else:
                self.device = torch.device("cpu")
                # fp16 is not supported on CPU; fall back to fp32
                use_fp16 = False
        self.model.to(self.device)
        self.sparse_linear.to(self.device)
        if use_fp16:
            self.model.half()
            self.sparse_linear.half()
    def load_model(self, model_name):
        if not os.path.exists(model_name):
            cache_folder = os.getenv('HF_HUB_CACHE')
            model_name = snapshot_download(repo_id=model_name,
                                           cache_dir=cache_folder,
                                           ignore_patterns=['flax_model.msgpack', 'rust_model.ot', 'tf_model.h5'])
        self.model = AutoModel.from_pretrained(model_name, trust_remote_code=True)
        self.sparse_linear = torch.nn.Linear(in_features=self.model.config.hidden_size, out_features=1)
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model.eval()
        if os.path.exists(os.path.join(model_name, 'sparse_linear.pt')):
            logger.info('loading existing sparse_linear---------')
            self.load_pooler(model_dir=model_name)
        else:
            logger.warning('The parameters of the sparse linear layer were not found')
    def dense_embedding(self, hidden_state, mask):
        if self.pooling_method == 'cls':
            # use the [CLS] token representation
            return hidden_state[:, 0]
        elif self.pooling_method == 'mean':
            # mean pooling over non-padding tokens
            s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
            d = mask.sum(dim=1, keepdim=True).float()
            return s / d

    def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = True):
        # one non-negative weight per token position
        token_weights = torch.relu(self.sparse_linear(hidden_state))
        return token_weights
    def _process_token_weights(self, token_weights: np.ndarray, input_ids: list) -> Dict[str, float]:
        # convert to a token -> weight dict, keeping the max weight per token
        result = defaultdict(int)
        unused_tokens = set([self.tokenizer.cls_token_id, self.tokenizer.eos_token_id, self.tokenizer.pad_token_id,
                             self.tokenizer.unk_token_id])
        # token_weights = np.ceil(token_weights * 100)
        for w, idx in zip(token_weights, input_ids):
            if idx not in unused_tokens and w > 0:
                token = self.tokenizer.decode([int(idx)])
                if w > result[token]:
                    result[token] = w
        return result
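
    # For illustration (the weights below are made up): for the input
    # "hello world", _process_token_weights might return something like
    # {'hello': 0.21, 'world': 0.17} - special tokens and zero-weight
    # positions are dropped, and repeated tokens keep their largest weight.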

    @torch.no_grad()
    def encode(self,
               texts: Union[str, List[str]] = None,
               dimension: int = None,
               max_length: int = 8192,
               batch_size: int = 16,
               return_dense: bool = True,
               return_sparse: bool = False):
        if dimension is None:
            dimension = self.model.config.hidden_size
        if isinstance(texts, str):
            texts = [texts]
        num_texts = len(texts)
        all_dense_vecs = []
        all_token_weights = []
        for i in range(0, num_texts, batch_size):
            batch = texts[i: i + batch_size]
            results = self._encode(batch, dimension, max_length, batch_size, return_dense, return_sparse)
            if return_dense:
                all_dense_vecs.append(results['dense_embeddings'])
            if return_sparse:
                all_token_weights.extend(results['token_weights'])
        if return_dense:
            all_dense_vecs = torch.cat(all_dense_vecs, dim=0)
        return {
            "dense_embeddings": all_dense_vecs,
            "token_weights": all_token_weights
        }

    @torch.no_grad()
    def _encode(self,
                texts: List[str] = None,
                dimension: int = None,
                max_length: int = 1024,
                batch_size: int = 16,
                return_dense: bool = True,
                return_sparse: bool = False):
        text_input = self.tokenizer(texts, padding=True, truncation=True, return_tensors='pt', max_length=max_length)
        text_input = {k: v.to(self.model.device) for k, v in text_input.items()}
        last_hidden_state = self.model(**text_input, return_dict=True).last_hidden_state
        output = {}
        if return_dense:
            dense_vecs = self.dense_embedding(last_hidden_state, text_input['attention_mask'])
            # truncate to the requested embedding dimension before normalizing
            dense_vecs = dense_vecs[:, :dimension]
            if self.normalized:
                dense_vecs = torch.nn.functional.normalize(dense_vecs, dim=-1)
            output['dense_embeddings'] = dense_vecs
        if return_sparse:
            token_weights = self.sparse_embedding(last_hidden_state, text_input['input_ids']).squeeze(-1)
            token_weights = list(map(self._process_token_weights, token_weights.detach().cpu().numpy().tolist(),
                                     text_input['input_ids'].cpu().numpy().tolist()))
            output['token_weights'] = token_weights
        return output

    def load_pooler(self, model_dir):
        sparse_state_dict = torch.load(os.path.join(model_dir, 'sparse_linear.pt'), map_location='cpu')
        self.sparse_linear.load_state_dict(sparse_state_dict)

    def _compute_sparse_scores(self, embs1, embs2):
        # dot product over the tokens shared by the two sparse embeddings
        scores = 0
        for token, weight in embs1.items():
            if token in embs2:
                scores += weight * embs2[token]
        return scores
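
    # For illustration (made-up weights): with embs1 = {'deep': 0.4, 'learning': 0.3}
    # and embs2 = {'deep': 0.5, 'model': 0.2}, only 'deep' overlaps, so the
    # sparse score is 0.4 * 0.5 = 0.2.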

    def compute_sparse_scores(self, embs1, embs2):
        scores = [self._compute_sparse_scores(emb1, emb2) for emb1, emb2 in zip(embs1, embs2)]
        return np.array(scores)

    def compute_dense_scores(self, embs1, embs2):
        # dot product; equals cosine similarity when embeddings are normalized
        scores = torch.sum(embs1 * embs2, dim=-1).cpu().detach().numpy()
        return scores

    @torch.no_grad()
    def compute_scores(self,
                       text_pairs: List[Tuple[str, str]],
                       dimension: int = None,
                       max_length: int = 1024,
                       batch_size: int = 16,
                       dense_weight=1.0,
                       sparse_weight=0.1):
        text1_list = [text_pair[0] for text_pair in text_pairs]
        text2_list = [text_pair[1] for text_pair in text_pairs]
        embs1 = self.encode(text1_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
        embs2 = self.encode(text2_list, dimension, max_length, batch_size, return_dense=True, return_sparse=True)
        # weighted combination of dense (semantic) and sparse (lexical) scores
        scores = self.compute_dense_scores(embs1['dense_embeddings'], embs2['dense_embeddings']) * dense_weight + \
                 self.compute_sparse_scores(embs1['token_weights'], embs2['token_weights']) * sparse_weight
        scores = scores.tolist()
        return scores
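

# Example usage - a minimal sketch, not part of the original file. The model id
# 'Alibaba-NLP/gte-multilingual-base' is an assumption; substitute any GTE
# checkpoint that ships a compatible sparse_linear.pt.
if __name__ == '__main__':
    model = GTEEmbedding('Alibaba-NLP/gte-multilingual-base', use_fp16=False, device='cpu')
    outputs = model.encode(["what is the capital of China?",
                            "how to implement quick sort in python?"],
                           return_dense=True, return_sparse=True)
    print(outputs['dense_embeddings'].shape)   # (2, hidden_size)
    print(outputs['token_weights'][0])         # token -> weight dict for the first text
    scores = model.compute_scores([("what is the capital of China?", "Beijing")])
    print(scores)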