Update modeling_densebackward_olmoe0125.py
Browse files
modeling_densebackward_olmoe0125.py
CHANGED
@@ -66,23 +66,12 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
66 |
weighted_output = current_output * weight
|
67 |
sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
|
68 |
|
69 |
-
#
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
device=flat_hidden.device),
|
75 |
-
current_output.unsqueeze(1)))
|
76 |
|
77 |
-
# 标记哪些专家被激活
|
78 |
-
expert_activated.index_copy_(0, top_x,
|
79 |
-
torch.zeros_like(expert_activated[0:top_x.size(0)]).scatter_(
|
80 |
-
1, expert_idx * torch.ones((top_x.size(0), 1),
|
81 |
-
dtype=torch.long,
|
82 |
-
device=flat_hidden.device),
|
83 |
-
torch.ones((top_x.size(0), 1),
|
84 |
-
dtype=torch.bool,
|
85 |
-
device=flat_hidden.device)))
|
86 |
# ---------- 稀疏计算结束 ----------
|
87 |
|
88 |
# ---------- Dense估计部分 ----------
|
@@ -126,7 +115,7 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
126 |
num_experts = routing_weights.size(1)
|
127 |
device = all_expert_outputs.device
|
128 |
|
129 |
-
#
|
130 |
dense_outputs = torch.zeros((total_tokens, hidden_dim), dtype=all_expert_outputs.dtype, device=device)
|
131 |
|
132 |
# 对每个token单独处理(此处仍需循环,但后续可进一步优化)
|
@@ -191,6 +180,55 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
191 |
|
192 |
return dense_outputs
|
193 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
194 |
|
195 |
class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
|
196 |
"""
|
|
|
66 |
weighted_output = current_output * weight
|
67 |
sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
|
68 |
|
69 |
+
# 直接为激活的token分配专家输出
|
70 |
+
for i in range(top_x.shape[0]):
|
71 |
+
token_idx = top_x[i]
|
72 |
+
all_expert_outputs[token_idx, expert_idx] = current_output[i]
|
73 |
+
expert_activated[token_idx, expert_idx] = True
|
|
|
|
|
74 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
75 |
# ---------- 稀疏计算结束 ----------
|
76 |
|
77 |
# ---------- Dense估计部分 ----------
|
|
|
115 |
num_experts = routing_weights.size(1)
|
116 |
device = all_expert_outputs.device
|
117 |
|
118 |
+
# 预分配结果张量,注意是hidden_dim而不是num_experts
|
119 |
dense_outputs = torch.zeros((total_tokens, hidden_dim), dtype=all_expert_outputs.dtype, device=device)
|
120 |
|
121 |
# 对每个token单独处理(此处仍需循环,但后续可进一步优化)
|
|
|
180 |
|
181 |
return dense_outputs
|
182 |
|
183 |
+
def estimate_dense_output(self, token_idx, activated, gate_prob, activated_outputs, all_routing, all_expert_outputs):
    """Estimate the dense (all-experts) output for one token from mini-batch info.

    For experts this token actually activated, the real expert outputs are used.
    For each non-activated expert, the output is approximated by averaging that
    expert's outputs over other tokens in the mini-batch — preferring tokens that
    both computed that expert and share at least one activated expert with the
    current token, falling back to all tokens otherwise.

    Args:
        token_idx: index of the current token (scalar; kept for interface
            compatibility, not used in the computation itself).
        activated: list of expert indices activated for this token, e.g. [1, 3].
        gate_prob: routing weights for this token, shape (num_experts,).
        activated_outputs: dict mapping each activated expert index to its actual
            output tensor of shape (hidden_dim,).
        all_routing: list (length N) of per-token activated-expert lists.
        all_expert_outputs: tensor of shape (N, num_experts, hidden_dim);
            entries are zero for token/expert pairs that were never computed.

    Returns:
        estimated_dense: tensor of shape (hidden_dim,) — the gate-weighted sum
        of per-expert (actual or estimated) outputs.
    """
    num_experts = gate_prob.size(0)

    def _masked_mean(outputs):
        # Average only rows that were actually computed (non-zero rows).
        # NOTE: a legitimately all-zero expert output is indistinguishable
        # from "not computed" under this heuristic (as in the original code).
        mask = (outputs.sum(dim=-1) != 0).to(outputs.dtype).unsqueeze(-1)
        if mask.sum() > 0:
            return (outputs * mask).sum(dim=0) / mask.sum()
        # All rows are zero: return a zero vector.
        return torch.zeros_like(outputs[0])

    dense_parts = {}
    # Activated experts: use their real outputs directly.
    for idx in activated:
        dense_parts[idx] = activated_outputs[idx]

    # Hoist the set conversion out of the loops (the original rebuilt
    # set(activated) and set(r_dec) for every token comparison).
    activated_set = set(activated)
    for i in range(num_experts):
        if i in activated_set:
            continue
        # Tokens that computed expert i AND share an activated expert with
        # the current token; isdisjoint avoids materializing the intersection.
        indices = [
            idx for idx, r_dec in enumerate(all_routing)
            if i in r_dec and not activated_set.isdisjoint(r_dec)
        ]
        if indices:
            dense_parts[i] = _masked_mean(all_expert_outputs[indices, i, :])
        else:
            # No suitable peer token: fall back to every token's column for expert i.
            dense_parts[i] = _masked_mean(all_expert_outputs[:, i, :])

    # Gate-weighted combination over all experts.
    estimated_dense = 0
    for i in range(num_experts):
        estimated_dense += gate_prob[i] * dense_parts[i]
    return estimated_dense
|
231 |
+
|
232 |
|
233 |
class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
|
234 |
"""
|