Update modeling_densebackward_olmoe0125.py
Browse files
modeling_densebackward_olmoe0125.py
CHANGED
@@ -89,7 +89,7 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
89 |
# 对每个 token i 和专家 e,聚合候选 token 的输出
|
90 |
sum_outputs = (candidate_mask_exp * all_expert_outputs_exp).sum(dim=1) # (N_tokens, num_experts, hidden_dim)
|
91 |
count_outputs = candidate_mask.sum(dim=1).unsqueeze(-1) # (N_tokens, num_experts, 1)
|
92 |
-
estimated_dense_all = torch.where(count_outputs > 0, sum_outputs / count_outputs,
|
93 |
torch.zeros_like(sum_outputs)) # (N_tokens, num_experts, hidden_dim)
|
94 |
|
95 |
# 对于激活的专家,直接使用当前 token 的输出
|
|
|
89 |
# 对每个 token i 和专家 e,聚合候选 token 的输出
|
90 |
sum_outputs = (candidate_mask_exp * all_expert_outputs_exp).sum(dim=1) # (N_tokens, num_experts, hidden_dim)
|
91 |
count_outputs = candidate_mask.sum(dim=1).unsqueeze(-1) # (N_tokens, num_experts, 1)
|
92 |
+
estimated_dense_all = torch.where(count_outputs > 0, sum_outputs / (count_outputs+1),
|
93 |
torch.zeros_like(sum_outputs)) # (N_tokens, num_experts, hidden_dim)
|
94 |
|
95 |
# 对于激活的专家,直接使用当前 token 的输出
|