File size: 3,783 Bytes
0f5c20a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
import torch


def get_local_score(q_reps, p_reps, all_scores):
    """获取 queries 和 passages 的局部得分。

    Args:
        q_reps (torch.Tensor): queries 的表示。
        p_reps (torch.Tensor): passages 的表示。
        all_scores (torch.Tensor): 计算得到的所有 query-passage 的得分。

    Returns:
        torch.Tensor: 用于计算损失的局部得分。
    """
    group_size = p_reps.size(0) // q_reps.size(0) # 每个 query 对应的 passages 数量
    indices = torch.arange(0, q_reps.size(0), device=q_reps.device) * group_size # 每个 query 在 all_scores 中的索引
    specific_scores = []
    for i in range(group_size):
        # 从 all_scores 中提取每个 query 对应的第 i 个 passage 的得分
        specific_scores.append(
            all_scores[torch.arange(q_reps.size(0), device=q_reps.device), indices + i] # (batch_size, group_size)
        )
    # 将所有特定得分堆叠在一起,并调整形状为 (batch_size, group_size)
    return torch.stack(specific_scores, dim=1).view(q_reps.size(0), -1)





def _compute_similarity(q_reps, p_reps):
    """使用内积计算 query 和 passage 表示之间的相似度。

    Args:
        q_reps (torch.Tensor): queries 的表示。
        p_reps (torch.Tensor): passages 的表示。

    Returns:
        torch.Tensor: 计算得到的相似度矩阵。
    """
    if len(p_reps.size()) == 2:
        return torch.matmul(q_reps, p_reps.transpose(0, 1))
    return torch.matmul(q_reps, p_reps.transpose(-2, -1))


def compute_score(q_reps, p_reps, temperature):
    """计算 queries 和 passages 之间的得分。

    Args:
        q_reps (torch.Tensor): queries 的表示。
        p_reps (torch.Tensor): passages 的表示。
        temperature (float): 温度参数,用于调整得分。

    Returns:
        torch.Tensor: 调整后的得分。
    """
    scores = _compute_similarity(q_reps, p_reps) / temperature  # (batch_size, group_size)
    scores = scores.view(q_reps.size(0), -1)  # (batch_size, group_size)
    return scores


def compute_local_score(q_reps, p_reps, temperature):
    """计算 queries 和 passages 的局部得分。

    Args:
        q_reps (torch.Tensor): queries 的表示。
        p_reps (torch.Tensor): passages 的表示。
        temperature (float): 温度参数,用于调整得分。

    Returns:
        torch.Tensor: 用于计算损失的局部得分。
    """
    all_scores = compute_score(q_reps, p_reps, temperature)

    loacl_scores = get_local_score(q_reps, p_reps, all_scores)
    return loacl_scores



def compute_loss(scores, target):
    """使用交叉熵计算损失。

    Args:
        scores (torch.Tensor): 计算得到的得分。
        target (torch.Tensor): 目标值。

    Returns:
        torch.Tensor: 计算得到的交叉熵损失。
    """
    cross_entropy = torch.nn.CrossEntropyLoss(reduction='mean')
    return cross_entropy(scores, target)


def compute_no_in_batch_neg_loss(q_reps, p_reps, temperature):
    """
    在不使用批内负样本和跨设备负样本的情况下计算损失。
    
    Args:
        q_reps (torch.Tensor): queries 的表示,形状为 (batch_size, dim)。
        p_reps (torch.Tensor): passages 的表示,形状为 (batch_size * group_size, dim)。
        temperature (float): 温度参数,用于调整得分。

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: 返回局部得分和计算得到的损失。
    """
    local_scores = compute_local_score(q_reps, p_reps, temperature)   # (batch_size, group_size)
    local_targets = torch.zeros(local_scores.size(0), device=local_scores.device, dtype=torch.long)  # (batch_size)
    
    loss = compute_loss(local_scores, local_targets)

    return local_scores, loss