Spaces:
Sleeping
Sleeping
import torch | |
def get_local_score(q_reps, p_reps, all_scores): | |
"""获取 queries 和 passages 的局部得分。 | |
Args: | |
q_reps (torch.Tensor): queries 的表示。 | |
p_reps (torch.Tensor): passages 的表示。 | |
all_scores (torch.Tensor): 计算得到的所有 query-passage 的得分。 | |
Returns: | |
torch.Tensor: 用于计算损失的局部得分。 | |
""" | |
group_size = p_reps.size(0) // q_reps.size(0) # 每个 query 对应的 passages 数量 | |
indices = torch.arange(0, q_reps.size(0), device=q_reps.device) * group_size # 每个 query 在 all_scores 中的索引 | |
specific_scores = [] | |
for i in range(group_size): | |
# 从 all_scores 中提取每个 query 对应的第 i 个 passage 的得分 | |
specific_scores.append( | |
all_scores[torch.arange(q_reps.size(0), device=q_reps.device), indices + i] # (batch_size, group_size) | |
) | |
# 将所有特定得分堆叠在一起,并调整形状为 (batch_size, group_size) | |
return torch.stack(specific_scores, dim=1).view(q_reps.size(0), -1) | |
def _compute_similarity(q_reps, p_reps): | |
"""使用内积计算 query 和 passage 表示之间的相似度。 | |
Args: | |
q_reps (torch.Tensor): queries 的表示。 | |
p_reps (torch.Tensor): passages 的表示。 | |
Returns: | |
torch.Tensor: 计算得到的相似度矩阵。 | |
""" | |
if len(p_reps.size()) == 2: | |
return torch.matmul(q_reps, p_reps.transpose(0, 1)) | |
return torch.matmul(q_reps, p_reps.transpose(-2, -1)) | |
def compute_score(q_reps, p_reps, temperature): | |
"""计算 queries 和 passages 之间的得分。 | |
Args: | |
q_reps (torch.Tensor): queries 的表示。 | |
p_reps (torch.Tensor): passages 的表示。 | |
temperature (float): 温度参数,用于调整得分。 | |
Returns: | |
torch.Tensor: 调整后的得分。 | |
""" | |
scores = _compute_similarity(q_reps, p_reps) / temperature # (batch_size, group_size) | |
scores = scores.view(q_reps.size(0), -1) # (batch_size, group_size) | |
return scores | |
def compute_local_score(q_reps, p_reps, temperature): | |
"""计算 queries 和 passages 的局部得分。 | |
Args: | |
q_reps (torch.Tensor): queries 的表示。 | |
p_reps (torch.Tensor): passages 的表示。 | |
temperature (float): 温度参数,用于调整得分。 | |
Returns: | |
torch.Tensor: 用于计算损失的局部得分。 | |
""" | |
all_scores = compute_score(q_reps, p_reps, temperature) | |
loacl_scores = get_local_score(q_reps, p_reps, all_scores) | |
return loacl_scores | |
def compute_loss(scores, target): | |
"""使用交叉熵计算损失。 | |
Args: | |
scores (torch.Tensor): 计算得到的得分。 | |
target (torch.Tensor): 目标值。 | |
Returns: | |
torch.Tensor: 计算得到的交叉熵损失。 | |
""" | |
cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean') | |
return cross_entropy(scores, target) | |
def compute_no_in_batch_neg_loss(q_reps, p_reps, temperature): | |
""" | |
在不使用批内负样本和跨设备负样本的情况下计算损失。 | |
Args: | |
q_reps (torch.Tensor): queries 的表示,形状为 (batch_size, dim)。 | |
p_reps (torch.Tensor): passages 的表示,形状为 (batch_size * group_size, dim)。 | |
temperature (float): 温度参数,用于调整得分。 | |
Returns: | |
Tuple[torch.Tensor, torch.Tensor]: 返回局部得分和计算得到的损失。 | |
""" | |
local_scores = compute_local_score(q_reps, p_reps, temperature) # (batch_size, group_size) | |
local_targets = torch.zeros(local_scores.size(0), device=local_scores.device, dtype=torch.long) # (batch_size) | |
loss = compute_loss(local_scores, local_targets) | |
return local_scores, loss |