import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss


class BiEncoderLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()
        # self.pooling_strategy = pooling_strategy

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, dim)
        doc_embeddings: (batch_size, dim)
        """
        # In-batch similarity matrix: scores[i, j] is the dot product between query i and document j.
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)

        # Cross-entropy against the diagonal: document i is the in-batch positive for query i.
        loss_rowwise = self.ce_loss(scores, torch.arange(scores.shape[0], device=scores.device))
        # loss_columnwise = self.ce_loss(scores.T, torch.arange(scores.shape[1], device=scores.device))
        # loss = (loss_rowwise + loss_columnwise) / 2
        return loss_rowwise
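

# Illustrative sketch (not part of the original module): how BiEncoderLoss might be called
# with pooled embeddings. The shapes and the embedding_dim value are assumptions chosen
# purely for demonstration.
def _bi_encoder_loss_example(batch_size: int = 4, embedding_dim: int = 128) -> torch.Tensor:
    loss_fn = BiEncoderLoss()
    query_embeddings = torch.randn(batch_size, embedding_dim)  # (batch_size, dim)
    doc_embeddings = torch.randn(batch_size, embedding_dim)  # row i is the positive for query i
    return loss_fn(query_embeddings, doc_embeddings)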


class ColbertLoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, num_query_tokens, dim)
        doc_embeddings: (batch_size, num_doc_tokens, dim)
        """
        # Late-interaction (MaxSim) scores: for every (query, doc) pair, each query token is matched
        # with its best-scoring document token, and these maxima are summed over the query tokens.
        scores = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)

        # Equivalent (but much slower) reference implementation, kept for documentation:
        # scores = torch.zeros((query_embeddings.shape[0], doc_embeddings.shape[0]), device=query_embeddings.device)
        # for i in range(query_embeddings.shape[0]):
        #     for j in range(doc_embeddings.shape[0]):
        #         # step 1 - dot products between all query and doc tokens --> (num_query_tokens, num_doc_tokens)
        #         q2d_scores = torch.matmul(query_embeddings[i], doc_embeddings[j].T)
        #         # step 2 - max over doc tokens --> (num_query_tokens,)
        #         q_scores = torch.max(q2d_scores, dim=1)[0]
        #         # step 3 - sum the per-token maxima --> scalar
        #         scores[i, j] = torch.sum(q_scores)

        # Cross-entropy against the diagonal: document i is the in-batch positive for query i.
        loss_rowwise = self.ce_loss(scores, torch.arange(scores.shape[0], device=scores.device))
        # TODO: comparing between queries might not make sense since the score is a sum over the query length
        # loss_columnwise = self.ce_loss(scores.T, torch.arange(scores.shape[1], device=scores.device))
        # loss = (loss_rowwise + loss_columnwise) / 2
        return loss_rowwise
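

# Illustrative sketch (not part of the original module): calling ColbertLoss with token-level
# embeddings. All sizes below are assumptions chosen purely for demonstration.
def _colbert_loss_example(
    batch_size: int = 4, num_query_tokens: int = 16, num_doc_tokens: int = 256, embedding_dim: int = 128
) -> torch.Tensor:
    loss_fn = ColbertLoss()
    query_embeddings = torch.randn(batch_size, num_query_tokens, embedding_dim)
    doc_embeddings = torch.randn(batch_size, num_doc_tokens, embedding_dim)
    return loss_fn(query_embeddings, doc_embeddings)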


class ColbertPairwiseCELoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, num_query_tokens, dim)
        doc_embeddings: (batch_size, num_doc_tokens, dim)

        Positive scores are the diagonal of the scores matrix.
        """
        # Compute the ColBERT (late-interaction) scores
        scores = (
            torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
        )  # (batch_size, batch_size)

        # Positive scores are the diagonal of the scores matrix.
        pos_scores = scores.diagonal()  # (batch_size,)

        # The negative score for a given query is the maximum of its scores against all other pages.
        # NOTE: the diagonal is excluded from the max by subtracting a large constant (1e6) from it,
        # which is far larger than any attainable score.
        neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6  # (batch_size, batch_size)
        neg_scores = neg_scores.max(dim=1)[0]  # (batch_size,)

        # The loss is the negative log of the softmax of the positive score relative to the negative score,
        # i.e. -log softmax([pos, neg])[0], which simplifies to softplus(neg - pos) and is numerically stable:
        # torch.vstack((pos_scores, neg_scores)).T.softmax(1)[:, 0].log() * (-1)
        loss = F.softplus(neg_scores - pos_scores).mean()
        return loss
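

# Illustrative sketch (not part of the original module): numerically checking the identity used
# by the pairwise losses, softplus(neg - pos) == -log softmax([pos, neg])[0], i.e. the negative
# log-probability of the positive score. The score values below are arbitrary.
def _softplus_identity_check() -> bool:
    pos_scores = torch.tensor([5.0, 2.0, -1.0])
    neg_scores = torch.tensor([3.0, 4.0, -2.0])
    lhs = F.softplus(neg_scores - pos_scores)
    rhs = -torch.stack((pos_scores, neg_scores), dim=1).log_softmax(dim=1)[:, 0]
    return torch.allclose(lhs, rhs)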


class BiPairwiseCELoss(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """
        query_embeddings: (batch_size, dim)
        doc_embeddings: (batch_size, dim)
        """
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)

        pos_scores = scores.diagonal()
        neg_scores = scores - torch.eye(scores.shape[0], device=scores.device) * 1e6
        neg_scores = neg_scores.max(dim=1)[0]

        # Compute the loss
        # The loss is computed as the negative log of the softmax of the positive scores
        # relative to the negative scores.
        # This can be simplified to log-sum-exp of negative scores minus the positive score
        # for numerical stability.
        loss = F.softplus(neg_scores - pos_scores).mean()
        return loss
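

# Illustrative smoke test (not part of the original module): exercises the four losses on random
# tensors. All shapes are assumptions chosen purely for demonstration.
if __name__ == "__main__":
    batch_size, num_query_tokens, num_doc_tokens, dim = 4, 16, 256, 128
    pooled_queries = torch.randn(batch_size, dim)
    pooled_docs = torch.randn(batch_size, dim)
    token_queries = torch.randn(batch_size, num_query_tokens, dim)
    token_docs = torch.randn(batch_size, num_doc_tokens, dim)

    print("BiEncoderLoss:", BiEncoderLoss()(pooled_queries, pooled_docs).item())
    print("BiPairwiseCELoss:", BiPairwiseCELoss()(pooled_queries, pooled_docs).item())
    print("ColbertLoss:", ColbertLoss()(token_queries, token_docs).item())
    print("ColbertPairwiseCELoss:", ColbertPairwiseCELoss()(token_queries, token_docs).item())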