HUANG-Stephanie's picture
Upload 88 files
9ff79dc verified
import torch
import torch.nn.functional as F
from torch.nn import CrossEntropyLoss
class BiEncoderLoss(torch.nn.Module):
    """In-batch softmax cross-entropy loss for single-vector (bi-encoder) embeddings.

    Query ``i`` is paired with document ``i``; every other document in the batch
    acts as a negative for that query.
    """

    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """Return the mean query-to-document cross-entropy loss.

        Args:
            query_embeddings: (batch_size, dim)
            doc_embeddings: (batch_size, dim)

        Returns:
            Scalar loss tensor.
        """
        # Dot-product similarity of every query against every document.
        similarity = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)
        # The matching document for query i sits in column i.
        target_columns = torch.arange(similarity.shape[0], device=similarity.device)
        # Only the query -> document direction is used (no symmetric column-wise term).
        return self.ce_loss(similarity, target_columns)
class ColbertLoss(torch.nn.Module):
    """In-batch cross-entropy loss over ColBERT late-interaction (MaxSim) scores.

    The score of query ``b`` against document ``c`` is the sum over query tokens of
    each token's best dot product with any token of document ``c``. Query ``i`` is
    paired with document ``i``.
    """

    def __init__(self):
        super().__init__()
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """Return the mean query-to-document cross-entropy loss.

        Args:
            query_embeddings: (batch_size, num_query_tokens, dim)
            doc_embeddings: (batch_size, num_doc_tokens, dim)

        Returns:
            Scalar loss tensor.
        """
        # token_sim[b, c, n, s] = <query b token n, document c token s>
        token_sim = torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings)
        # Best-matching document token for every query token ...
        best_match = token_sim.max(dim=3).values
        # ... summed over the query tokens -> (batch_size, batch_size)
        scores = best_match.sum(dim=2)
        targets = torch.arange(scores.shape[0], device=scores.device)
        # TODO: a column-wise (document -> query) term is deliberately omitted, since
        # comparing queries might not make sense when the score is a sum over the
        # length of each query.
        return self.ce_loss(scores, targets)
class ColbertPairwiseCELoss(torch.nn.Module):
    """Pairwise softplus loss on ColBERT (MaxSim) scores using the hardest in-batch negative.

    Query ``i`` is paired with document ``i``. Its negative is the highest-scoring
    document other than document ``i``. Extra documents beyond the number of queries
    (e.g. appended hard negatives) act only as negatives.
    """

    def __init__(self):
        super().__init__()
        # Not used by ``forward``; kept so code that accesses ``ce_loss`` keeps working.
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """Compute the mean pairwise loss.

        Args:
            query_embeddings: (num_queries, num_query_tokens, dim)
            doc_embeddings: (num_docs, num_doc_tokens, dim) with num_docs >= num_queries.
                Document ``i`` is the positive for query ``i``.

        Returns:
            Scalar tensor: mean over queries of ``softplus(neg_score - pos_score)``.

        Raises:
            ValueError: if there are fewer documents than queries.
        """
        num_queries, num_docs = query_embeddings.shape[0], doc_embeddings.shape[0]
        if num_docs < num_queries:
            raise ValueError(
                "doc_embeddings must have at least as many rows as query_embeddings "
                f"(got {num_docs} docs for {num_queries} queries)"
            )
        # MaxSim: for each (query, doc) pair, take each query token's best-matching doc
        # token and sum over query tokens -> (num_queries, num_docs).
        scores = (
            torch.einsum("bnd,csd->bcns", query_embeddings, doc_embeddings).max(dim=3)[0].sum(dim=2)
        )
        # Positive scores sit on the (possibly rectangular) main diagonal.
        pos_scores = scores.diagonal()  # (num_queries,)
        # Hardest negative: best score among all other documents. The positive entries
        # are pushed down by a large constant (1e6) so the max never picks them.
        # ColBERT scores are sums over query tokens and are NOT bounded by 1.
        # torch.eye(rows, cols) also handles rectangular score matrices.
        neg_scores = scores - torch.eye(*scores.shape, device=scores.device) * 1e6
        neg_scores = neg_scores.max(dim=1)[0]  # (num_queries,)
        # -log(softmax([pos, neg])[0]) == log(1 + exp(neg - pos)) == softplus(neg - pos);
        # softplus is the numerically stable way to evaluate it.
        return F.softplus(neg_scores - pos_scores).mean()
class BiPairwiseCELoss(torch.nn.Module):
    """Pairwise softplus loss on bi-encoder dot-product scores using the hardest in-batch negative.

    Query ``i`` is paired with document ``i``. Its negative is the highest-scoring
    document other than document ``i``. Extra documents beyond the number of queries
    (e.g. appended hard negatives) act only as negatives.
    """

    def __init__(self):
        super().__init__()
        # Not used by ``forward``; kept so code that accesses ``ce_loss`` keeps working.
        self.ce_loss = CrossEntropyLoss()

    def forward(self, query_embeddings, doc_embeddings):
        """Compute the mean pairwise loss.

        Args:
            query_embeddings: (num_queries, dim)
            doc_embeddings: (num_docs, dim) with num_docs >= num_queries.
                Document ``i`` is the positive for query ``i``.

        Returns:
            Scalar tensor: mean over queries of ``softplus(neg_score - pos_score)``.

        Raises:
            ValueError: if there are fewer documents than queries.
        """
        num_queries, num_docs = query_embeddings.shape[0], doc_embeddings.shape[0]
        if num_docs < num_queries:
            raise ValueError(
                "doc_embeddings must have at least as many rows as query_embeddings "
                f"(got {num_docs} docs for {num_queries} queries)"
            )
        scores = torch.einsum("bd,cd->bc", query_embeddings, doc_embeddings)  # (num_queries, num_docs)
        pos_scores = scores.diagonal()  # (num_queries,)
        # Push the positive entries down by a large constant so the row-wise max picks
        # the hardest negative; torch.eye(rows, cols) supports rectangular scores.
        neg_scores = scores - torch.eye(*scores.shape, device=scores.device) * 1e6
        neg_scores = neg_scores.max(dim=1)[0]  # (num_queries,)
        # -log(softmax([pos, neg])[0]) == softplus(neg - pos), the numerically stable form.
        return F.softplus(neg_scores - pos_scores).mean()