autoprogrammer committed on
Commit
e1f5244
·
verified ·
1 Parent(s): 8658bac

Update modeling_densebackward_olmoe0125.py

Browse files
Files changed (1) hide show
  1. modeling_densebackward_olmoe0125.py +55 -17
modeling_densebackward_olmoe0125.py CHANGED
@@ -66,23 +66,12 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
66
  weighted_output = current_output * weight
67
  sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
68
 
69
- # 保存专家输出到张量中,而不是使用字典
70
- all_expert_outputs.index_copy_(0, top_x,
71
- torch.zeros_like(all_expert_outputs[0:top_x.size(0)]).scatter_(
72
- 1, expert_idx * torch.ones((top_x.size(0), 1),
73
- dtype=torch.long,
74
- device=flat_hidden.device),
75
- current_output.unsqueeze(1)))
76
 
77
- # 标记哪些专家被激活
78
- expert_activated.index_copy_(0, top_x,
79
- torch.zeros_like(expert_activated[0:top_x.size(0)]).scatter_(
80
- 1, expert_idx * torch.ones((top_x.size(0), 1),
81
- dtype=torch.long,
82
- device=flat_hidden.device),
83
- torch.ones((top_x.size(0), 1),
84
- dtype=torch.bool,
85
- device=flat_hidden.device)))
86
  # ---------- 稀疏计算结束 ----------
87
 
88
  # ---------- Dense估计部分 ----------
@@ -126,7 +115,7 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
126
  num_experts = routing_weights.size(1)
127
  device = all_expert_outputs.device
128
 
129
- # 预分配结果张量
130
  dense_outputs = torch.zeros((total_tokens, hidden_dim), dtype=all_expert_outputs.dtype, device=device)
131
 
132
  # 对每个token单独处理(此处仍需循环,但后续可进一步优化)
@@ -191,6 +180,55 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
191
 
192
  return dense_outputs
193
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
194
 
195
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
196
  """
 
66
  weighted_output = current_output * weight
67
  sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
68
 
69
+ # 直接为激活的token分配专家输出
70
+ for i in range(top_x.shape[0]):
71
+ token_idx = top_x[i]
72
+ all_expert_outputs[token_idx, expert_idx] = current_output[i]
73
+ expert_activated[token_idx, expert_idx] = True
 
 
74
 
 
 
 
 
 
 
 
 
 
75
  # ---------- 稀疏计算结束 ----------
76
 
77
  # ---------- Dense估计部分 ----------
 
115
  num_experts = routing_weights.size(1)
116
  device = all_expert_outputs.device
117
 
118
+ # 预分配结果张量,注意是hidden_dim而不是num_experts
119
  dense_outputs = torch.zeros((total_tokens, hidden_dim), dtype=all_expert_outputs.dtype, device=device)
120
 
121
  # 对每个token单独处理(此处仍需循环,但后续可进一步优化)
 
180
 
181
  return dense_outputs
182
 
183
def estimate_dense_output(self, token_idx, activated, gate_prob, activated_outputs, all_routing, all_expert_outputs):
    """
    Estimate the dense (all-experts) output for one token from mini-batch statistics.

    For experts this token actually activated, the real expert output is used.
    For every other expert, the output is approximated by averaging that
    expert's outputs over other tokens in the mini-batch, preferring tokens
    whose routing overlaps with this token's activated experts.

    Args:
        token_idx: index of the current token (scalar). Kept for interface
            compatibility; not used in the computation.
        activated: list of expert indices activated for this token, e.g. [1, 3].
        gate_prob: routing weights for this token, shape (num_experts,).
        activated_outputs: dict mapping an activated expert index to its
            actual output tensor of shape (hidden_dim,).
        all_routing: list of length N; element i is the list of experts
            activated by token i.
        all_expert_outputs: tensor of shape (N, num_experts, hidden_dim).

    Returns:
        estimated_dense: tensor of shape (hidden_dim,) — the gate-weighted
        combination of real and estimated expert outputs.
    """

    def masked_mean(outputs):
        # Average only the rows that are non-zero (tokens for which this
        # expert actually produced an output); if none, return a zero vector.
        mask = (outputs.sum(dim=-1) != 0).to(outputs.dtype).unsqueeze(-1)
        if mask.sum() > 0:
            return (outputs * mask).sum(dim=0) / mask.sum()
        return torch.zeros_like(outputs[0])

    num_experts = gate_prob.size(0)
    # Hoisted: the original rebuilt set(activated) inside the nested loop,
    # once per (expert, token) pair.
    activated_set = set(activated)
    dense_parts = {}

    # Activated experts: use their actual outputs directly.
    for idx in activated:
        dense_parts[idx] = activated_outputs[idx]

    # Non-activated experts: estimate from other tokens in the mini-batch.
    for i in range(num_experts):
        if i in activated_set:
            continue
        # Prefer tokens that activated expert i AND share at least one
        # activated expert with the current token.
        indices = [
            idx for idx, r_dec in enumerate(all_routing)
            if i in r_dec and activated_set & set(r_dec)
        ]
        if indices:
            estimated = masked_mean(all_expert_outputs[indices, i, :])
        else:
            # No related token used expert i; average over the whole batch.
            estimated = masked_mean(all_expert_outputs[:, i, :])
        dense_parts[i] = estimated

    # Gate-weighted sum over all experts (real + estimated).
    estimated_dense = 0
    for i in range(num_experts):
        estimated_dense += gate_prob[i] * dense_parts[i]
    return estimated_dense
231
+
232
 
233
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
234
  """