Update modeling_densebackward_olmoe0125.py
modeling_densebackward_olmoe0125.py  CHANGED
@@ -65,21 +65,21 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
         num_experts = self.num_experts

         # Convert selected_experts into a one-hot binary matrix R: (N_tokens, num_experts)
-        R = F.one_hot(selected_experts, num_classes=num_experts).
+        R = F.one_hot(selected_experts, num_classes=num_experts).to(flat_hidden.dtype)  # (N_tokens, top_k, num_experts)
         R = R.sum(dim=1)  # (N_tokens, num_experts); activated expert positions have values > 0

         # Compute the shared-activation matrix S between tokens: (N_tokens, N_tokens)
         S = torch.matmul(R, R.t())  # S[i, j] > 0 means token i and token j share at least one activated expert
-        S = S * (1 - torch.eye(N_tokens, device=S.device))  # remove self-matches
+        S = S * (1 - torch.eye(N_tokens, device=S.device, dtype=flat_hidden.dtype))  # remove self-matches

         # Build the candidate mask M: (N_tokens, N_tokens, num_experts)
         # M[i, j, e] = 1 means token j activated expert e and token i shares at least one activated expert with token j
         R_expanded = R.unsqueeze(0).expand(N_tokens, -1, -1)  # (N_tokens, N_tokens, num_experts)
         S_expanded = S.unsqueeze(-1)  # (N_tokens, N_tokens, 1)
-        candidate_mask = ((R_expanded > 0) & (S_expanded > 0)).
+        candidate_mask = ((R_expanded > 0) & (S_expanded > 0)).to(flat_hidden.dtype)  # (N_tokens, N_tokens, num_experts)

         # For numerical stability, exclude the token itself (zero out the diagonal)
-        candidate_mask = candidate_mask * (1 - torch.eye(N_tokens, device=candidate_mask.device).unsqueeze(-1))
+        candidate_mask = candidate_mask * (1 - torch.eye(N_tokens, device=candidate_mask.device, dtype=flat_hidden.dtype).unsqueeze(-1))

         # Expand the mask and all_expert_outputs for batched aggregation
         # all_expert_outputs: (N_tokens, num_experts, hidden_dim)
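In the updated lines, every tensor that feeds the mask math is cast to flat_hidden.dtype: the one-hot matrix R, the torch.eye diagonal masks (which otherwise default to float32), and candidate_mask itself, so the later multiplications against the bf16/fp16 hidden states and expert outputs stay in a single dtype instead of silently upcasting. The following is a minimal sketch of the same mask construction with made-up sizes, assuming toy values for N_tokens, top_k, selected_experts, and flat_hidden; it is an illustration, not the module's actual forward code.

import torch
import torch.nn.functional as F

# Dummy sizes; the real block derives these from its inputs.
N_tokens, hidden_dim, num_experts, top_k = 6, 8, 4, 2
flat_hidden = torch.randn(N_tokens, hidden_dim, dtype=torch.bfloat16)
# Stand-in router decision: indices of the top_k experts chosen per token.
selected_experts = torch.randint(0, num_experts, (N_tokens, top_k))

# One-hot routing matrix R, cast to the hidden-state dtype (the cast is what the commit adds).
R = F.one_hot(selected_experts, num_classes=num_experts).to(flat_hidden.dtype)  # (N, top_k, E)
R = R.sum(dim=1)                                                                # (N, E)

# S[i, j] > 0 iff tokens i and j share at least one activated expert.
S = torch.matmul(R, R.t())
S = S * (1 - torch.eye(N_tokens, device=S.device, dtype=flat_hidden.dtype))     # drop self-matches

# candidate_mask[i, j, e] = 1 iff token j activated expert e and shares an expert with token i.
R_expanded = R.unsqueeze(0).expand(N_tokens, -1, -1)   # (N, N, E)
S_expanded = S.unsqueeze(-1)                           # (N, N, 1)
candidate_mask = ((R_expanded > 0) & (S_expanded > 0)).to(flat_hidden.dtype)
candidate_mask = candidate_mask * (
    1 - torch.eye(N_tokens, device=candidate_mask.device, dtype=flat_hidden.dtype).unsqueeze(-1)
)

print(candidate_mask.dtype, tuple(candidate_mask.shape))  # torch.bfloat16 (6, 6, 4)

With these toy sizes the printed result is a bfloat16 mask of shape (6, 6, 4), matching the shape comments in the diff.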