autoprogrammer commited on
Commit
9f79841
·
verified ·
1 Parent(s): 7233810

Update modeling_densebackward_olmoe0125.py

Browse files
modeling_densebackward_olmoe0125.py CHANGED
@@ -65,21 +65,21 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
65
  num_experts = self.num_experts
66
 
67
  # 将 selected_experts 转换为 one-hot 二值矩阵 R: (N_tokens, num_experts)
68
- R = F.one_hot(selected_experts, num_classes=num_experts).float() # (N_tokens, top_k, num_experts)
69
  R = R.sum(dim=1) # (N_tokens, num_experts),激活的专家位置值大于0
70
 
71
  # 计算 token 之间共享激活情况 S: (N_tokens, N_tokens)
72
  S = torch.matmul(R, R.t()) # S[i,j] > 0 表示 token i 和 token j 至少共享一个激活专家
73
- S = S * (1 - torch.eye(N_tokens, device=S.device)) # 去除自身
74
 
75
  # 构造候选 mask M: (N_tokens, N_tokens, num_experts)
76
  # M[i, j, e] = 1 表示 token j 激活了专家 e 且 token i 与 token j 至少共享一个激活专家
77
  R_expanded = R.unsqueeze(0).expand(N_tokens, -1, -1) # (N_tokens, N_tokens, num_experts)
78
  S_expanded = S.unsqueeze(-1) # (N_tokens, N_tokens, 1)
79
- candidate_mask = ((R_expanded > 0) & (S_expanded > 0)).float() # (N_tokens, N_tokens, num_experts)
80
 
81
  # 对于数值稳定,排除 token 自身(对角线置0)
82
- candidate_mask = candidate_mask * (1 - torch.eye(N_tokens, device=candidate_mask.device).unsqueeze(-1))
83
 
84
  # 扩展 mask 和 all_expert_outputs 以便批量聚合
85
  # all_expert_outputs: (N_tokens, num_experts, hidden_dim)
 
65
  num_experts = self.num_experts
66
 
67
  # 将 selected_experts 转换为 one-hot 二值矩阵 R: (N_tokens, num_experts)
68
+ R = F.one_hot(selected_experts, num_classes=num_experts).to(flat_hidden.dtype) # (N_tokens, top_k, num_experts)
69
  R = R.sum(dim=1) # (N_tokens, num_experts),激活的专家位置值大于0
70
 
71
  # 计算 token 之间共享激活情况 S: (N_tokens, N_tokens)
72
  S = torch.matmul(R, R.t()) # S[i,j] > 0 表示 token i 和 token j 至少共享一个激活专家
73
+ S = S * (1 - torch.eye(N_tokens, device=S.device, dtype=flat_hidden.dtype)) # 去除自身
74
 
75
  # 构造候选 mask M: (N_tokens, N_tokens, num_experts)
76
  # M[i, j, e] = 1 表示 token j 激活了专家 e 且 token i 与 token j 至少共享一个激活专家
77
  R_expanded = R.unsqueeze(0).expand(N_tokens, -1, -1) # (N_tokens, N_tokens, num_experts)
78
  S_expanded = S.unsqueeze(-1) # (N_tokens, N_tokens, 1)
79
+ candidate_mask = ((R_expanded > 0) & (S_expanded > 0)).to(flat_hidden.dtype) # (N_tokens, N_tokens, num_experts)
80
 
81
  # 对于数值稳定,排除 token 自身(对角线置0)
82
+ candidate_mask = candidate_mask * (1 - torch.eye(N_tokens, device=candidate_mask.device, dtype=flat_hidden.dtype).unsqueeze(-1))
83
 
84
  # 扩展 mask 和 all_expert_outputs 以便批量聚合
85
  # all_expert_outputs: (N_tokens, num_experts, hidden_dim)