Update modeling_densebackward_olmoe0125.py
Browse files
modeling_densebackward_olmoe0125.py
CHANGED
@@ -80,7 +80,13 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
80 |
|
81 |
# ---------- Dense估计部分 ----------
|
82 |
# 计算所有专家对所有 token 的 dense 输出,shape: (B*seq_len, num_experts, hidden_dim)
|
83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
84 |
# 将 selected_experts 转换为 list,每个 token 的激活专家列表
|
85 |
all_routing = selected_experts.tolist() # 长度为 (B*seq_len)
|
86 |
|
@@ -130,9 +136,21 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
130 |
indices.append(idx)
|
131 |
if indices:
|
132 |
selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
|
133 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
134 |
else:
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
136 |
dense_parts[i] = estimated
|
137 |
# 按 gate_prob 加权求和各专家输出
|
138 |
estimated_dense = 0
|
|
|
80 |
|
81 |
# ---------- Dense估计部分 ----------
|
82 |
# 计算所有专家对所有 token 的 dense 输出,shape: (B*seq_len, num_experts, hidden_dim)
|
83 |
+
# 创建全零张量,只填入已激活专家的输出
|
84 |
+
all_expert_outputs = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
|
85 |
+
dtype=flat_hidden.dtype, device=flat_hidden.device)
|
86 |
+
# 填入已激活专家的输出
|
87 |
+
for i in range(flat_hidden.size(0)):
|
88 |
+
for expert_idx in activated_outputs[i].keys():
|
89 |
+
all_expert_outputs[i, expert_idx] = activated_outputs[i][expert_idx]
|
90 |
# 将 selected_experts 转换为 list,每个 token 的激活专家列表
|
91 |
all_routing = selected_experts.tolist() # 长度为 (B*seq_len)
|
92 |
|
|
|
136 |
indices.append(idx)
|
137 |
if indices:
|
138 |
selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
|
139 |
+
# 只计算非零值的平均值
|
140 |
+
mask = (selected_outputs.sum(dim=-1) != 0).float().unsqueeze(-1) # (n, 1)
|
141 |
+
if mask.sum() > 0:
|
142 |
+
estimated = (selected_outputs * mask).sum(dim=0) / mask.sum()
|
143 |
+
else:
|
144 |
+
# 如果全是零,返回零向量
|
145 |
+
estimated = torch.zeros_like(selected_outputs[0])
|
146 |
else:
|
147 |
+
all_outputs = all_expert_outputs[:, i, :]
|
148 |
+
mask = (all_outputs.sum(dim=-1) != 0).float().unsqueeze(-1) # (N, 1)
|
149 |
+
if mask.sum() > 0:
|
150 |
+
estimated = (all_outputs * mask).sum(dim=0) / mask.sum()
|
151 |
+
else:
|
152 |
+
# 如果全是零,返回零向量
|
153 |
+
estimated = torch.zeros_like(all_outputs[0])
|
154 |
dense_parts[i] = estimated
|
155 |
# 按 gate_prob 加权求和各专家输出
|
156 |
estimated_dense = 0
|