autoprogrammer commited on
Commit
0314bc4
·
verified ·
1 Parent(s): a71d5b3

Update modeling_densebackward_olmoe0125.py

Browse files
Files changed (1) hide show
  1. modeling_densebackward_olmoe0125.py +21 -3
modeling_densebackward_olmoe0125.py CHANGED
@@ -80,7 +80,13 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
80
 
81
  # ---------- Dense估计部分 ----------
82
  # 计算所有专家对所有 token 的 dense 输出,shape: (B*seq_len, num_experts, hidden_dim)
83
- all_expert_outputs = torch.stack([expert(flat_hidden) for expert in self.experts], dim=1)
 
 
 
 
 
 
84
  # 将 selected_experts 转换为 list,每个 token 的激活专家列表
85
  all_routing = selected_experts.tolist() # 长度为 (B*seq_len)
86
 
@@ -130,9 +136,21 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
130
  indices.append(idx)
131
  if indices:
132
  selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
133
- estimated = selected_outputs.mean(dim=0)
 
 
 
 
 
 
134
  else:
135
- estimated = all_expert_outputs[:, i, :].mean(dim=0)
 
 
 
 
 
 
136
  dense_parts[i] = estimated
137
  # 按 gate_prob 加权求和各专家输出
138
  estimated_dense = 0
 
80
 
81
  # ---------- Dense估计部分 ----------
82
  # 计算所有专家对所有 token 的 dense 输出,shape: (B*seq_len, num_experts, hidden_dim)
83
+ # 创建全零张量,只填入已激活专家的输出
84
+ all_expert_outputs = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
85
+ dtype=flat_hidden.dtype, device=flat_hidden.device)
86
+ # 填入已激活专家的输出
87
+ for i in range(flat_hidden.size(0)):
88
+ for expert_idx in activated_outputs[i].keys():
89
+ all_expert_outputs[i, expert_idx] = activated_outputs[i][expert_idx]
90
  # 将 selected_experts 转换为 list,每个 token 的激活专家列表
91
  all_routing = selected_experts.tolist() # 长度为 (B*seq_len)
92
 
 
136
  indices.append(idx)
137
  if indices:
138
  selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
139
+ # 只计算非零值的平均值
140
+ mask = (selected_outputs.sum(dim=-1) != 0).float().unsqueeze(-1) # (n, 1)
141
+ if mask.sum() > 0:
142
+ estimated = (selected_outputs * mask).sum(dim=0) / mask.sum()
143
+ else:
144
+ # 如果全是零,返回零向量
145
+ estimated = torch.zeros_like(selected_outputs[0])
146
  else:
147
+ all_outputs = all_expert_outputs[:, i, :]
148
+ mask = (all_outputs.sum(dim=-1) != 0).float().unsqueeze(-1) # (N, 1)
149
+ if mask.sum() > 0:
150
+ estimated = (all_outputs * mask).sum(dim=0) / mask.sum()
151
+ else:
152
+ # 如果全是零,返回零向量
153
+ estimated = torch.zeros_like(all_outputs[0])
154
  dense_parts[i] = estimated
155
  # 按 gate_prob 加权求和各专家输出
156
  estimated_dense = 0