# bge-m3-hf / modeling_bge_m3.py
from collections import defaultdict
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple, Union
import torch
from torch import nn
from transformers.modeling_outputs import BaseModelOutputWithPoolingAndCrossAttentions, ModelOutput
from transformers.models.xlm_roberta import (
XLMRobertaModel,
XLMRobertaPreTrainedModel,
)
from .configuration_bge_m3 import BgeM3Config
@dataclass
class BgeM3ModelOutput(ModelOutput):
last_hidden_state: torch.FloatTensor = None
pooler_output: torch.FloatTensor = None
dense_output: torch.FloatTensor = None
colbert_output: Optional[List[torch.FloatTensor]] = None
sparse_output: Optional[Dict[int, float]] = None
hidden_states: Optional[Tuple[torch.FloatTensor]] = None
past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
attentions: Optional[Tuple[torch.FloatTensor]] = None
cross_attentions: Optional[Tuple[torch.FloatTensor]] = None
class BgeM3Model(XLMRobertaPreTrainedModel):
config_class = BgeM3Config
def __init__(self, config: BgeM3Config):
super().__init__(config)
self.roberta = XLMRobertaModel(config, add_pooling_layer=False)
# TODO: Check the dtype of these linear layers
self.colbert_linear = nn.Linear(
in_features=config.hidden_size,
out_features=config.hidden_size if config.colbert_dim is None else config.colbert_dim,
)
self.sparse_linear = nn.Linear(in_features=config.hidden_size, out_features=1)
self.sentence_pooling_method = config.sentence_pooling_method
self.init_weights()
# Copied from FlagEmbedding
    def dense_embedding(self, hidden_state, mask):
        """Pool the last hidden state into a single dense vector per sequence."""
        if self.sentence_pooling_method == "cls":
            return hidden_state[:, 0]
        elif self.sentence_pooling_method == "mean":
            # Mask out padding positions before averaging over the sequence
            s = torch.sum(hidden_state * mask.unsqueeze(-1).float(), dim=1)
            d = mask.sum(dim=1, keepdim=True).float()
            return s / d
        raise ValueError(f"Unknown sentence_pooling_method: {self.sentence_pooling_method}")
# Copied from FlagEmbedding
def sparse_embedding(self, hidden_state, input_ids, return_embedding: bool = False):
token_weights = torch.relu(self.sparse_linear(hidden_state))
if not return_embedding:
return token_weights
sparse_embedding = torch.zeros(
input_ids.size(0),
input_ids.size(1),
self.config.vocab_size,
dtype=token_weights.dtype,
device=token_weights.device,
)
sparse_embedding = torch.scatter(sparse_embedding, dim=-1, index=input_ids.unsqueeze(-1), src=token_weights)
unused_tokens = self.config.unused_tokens
sparse_embedding = torch.max(sparse_embedding, dim=1).values
sparse_embedding[:, unused_tokens] *= 0.0
return sparse_embedding
# Copied from FlagEmbedding
def colbert_embedding(self, last_hidden_state, mask):
colbert_vecs = self.colbert_linear(last_hidden_state[:, 1:])
colbert_vecs = colbert_vecs * mask[:, 1:][:, :, None].float()
return colbert_vecs
# Modified from FlagEmbedding
def _process_token_weights(self, token_weights, input_ids, mask):
token_weights = token_weights.squeeze(-1)
        # Convert each row of token weights into a {token_id: max_weight} dict
all_result = []
unused_tokens = self.config.unused_tokens
unused_tokens = torch.tensor(unused_tokens, device=input_ids.device)
# Get valid matrix
valid_indices = ~torch.isin(input_ids, unused_tokens)
# w>0
valid_indices = (valid_indices & (token_weights > 0)).bool()
valid_indices = (valid_indices & mask).bool()
for i, valid in enumerate(valid_indices):
result = defaultdict(int)
# Get valid weight and ids
valid_weights = token_weights[i][valid]
valid_ids = input_ids[i][valid]
# Get unique token
unique_ids, inverse_indices = torch.unique(valid_ids, return_inverse=True)
# Get max weight for each token
            for j in range(unique_ids.shape[0]):
                id_mask = inverse_indices == j
                result[str(unique_ids[j].item())] = valid_weights[id_mask].max().item()
all_result.append(result)
return all_result
# Copied from FlagEmbedding
def _process_colbert_vecs(self, colbert_vecs, tokens_num) -> List[torch.Tensor]:
        # Delete the vectors of padding tokens
vecs = []
for i in range(len(tokens_num)):
vecs.append(colbert_vecs[i, : tokens_num[i] - 1])
return vecs
    # Modified from transformers.models.bert.modeling_bert.BertModel.forward
def forward(
self,
input_ids: Optional[torch.Tensor] = None,
attention_mask: Optional[torch.Tensor] = None,
token_type_ids: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
head_mask: Optional[torch.Tensor] = None,
inputs_embeds: Optional[torch.Tensor] = None,
encoder_hidden_states: Optional[torch.Tensor] = None,
encoder_attention_mask: Optional[torch.Tensor] = None,
past_key_values: Optional[List[torch.FloatTensor]] = None,
use_cache: Optional[bool] = None,
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
) -> Union[Tuple[torch.Tensor], BgeM3ModelOutput]:
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        roberta_output: BaseModelOutputWithPoolingAndCrossAttentions = self.roberta(
input_ids,
attention_mask=attention_mask,
token_type_ids=token_type_ids,
position_ids=position_ids,
head_mask=head_mask,
inputs_embeds=inputs_embeds,
encoder_hidden_states=encoder_hidden_states,
encoder_attention_mask=encoder_attention_mask,
past_key_values=past_key_values,
use_cache=use_cache,
output_attentions=output_attentions,
output_hidden_states=output_hidden_states,
return_dict=True,
)
last_hidden_state = roberta_output.last_hidden_state
dense_output = self.dense_embedding(last_hidden_state, attention_mask)
tokens_num = attention_mask.sum(dim=1)
colbert_output = self.colbert_embedding(last_hidden_state, attention_mask)
colbert_output = self._process_colbert_vecs(colbert_output, tokens_num)
sparse_output = self.sparse_embedding(last_hidden_state, input_ids)
sparse_output = self._process_token_weights(sparse_output, input_ids, attention_mask)
if not return_dict:
return (
last_hidden_state,
roberta_output.pooler_output,
dense_output,
colbert_output,
sparse_output,
roberta_output.hidden_states,
roberta_output.past_key_values,
roberta_output.attentions,
roberta_output.cross_attentions,
)
return BgeM3ModelOutput(
last_hidden_state=last_hidden_state,
dense_output=dense_output,
pooler_output=roberta_output.pooler_output,
colbert_output=colbert_output,
sparse_output=sparse_output,
hidden_states=roberta_output.hidden_states,
past_key_values=roberta_output.past_key_values,
attentions=roberta_output.attentions,
cross_attentions=roberta_output.cross_attentions,
)
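

# --- Minimal usage sketch (not part of the original file) ---
# Shows how the three outputs (dense, ColBERT multi-vector, sparse lexical
# weights) can be obtained from this module. The repo id below is a
# placeholder assumption; it presumes the checkpoint ships a BgeM3Config plus
# the standard XLM-RoBERTa tokenizer used by BGE-M3.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    repo_id = "your-org/bge-m3-hf"  # hypothetical repo id; replace with the real one
    tokenizer = AutoTokenizer.from_pretrained(repo_id)
    config = BgeM3Config.from_pretrained(repo_id)
    model = BgeM3Model.from_pretrained(repo_id, config=config)
    model.eval()

    batch = tokenizer(["What is BGE-M3?"], return_tensors="pt", padding=True)
    with torch.no_grad():
        out = model(**batch, return_dict=True)

    print(out.dense_output.shape)       # (batch, hidden_size) dense embedding
    print(out.colbert_output[0].shape)  # (num_tokens, colbert_dim) multi-vector embedding
    print(out.sparse_output[0])         # {token_id: weight} lexical weights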