Update modeling_densebackward_olmoe0125.py
Browse files
modeling_densebackward_olmoe0125.py
CHANGED
@@ -137,7 +137,7 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
137 |
if indices:
|
138 |
selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
|
139 |
# 只计算非零值的平均值
|
140 |
-
mask = (selected_outputs.sum(dim=-1) != 0).
|
141 |
if mask.sum() > 0:
|
142 |
estimated = (selected_outputs * mask).sum(dim=0) / mask.sum()
|
143 |
else:
|
@@ -145,7 +145,7 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
145 |
estimated = torch.zeros_like(selected_outputs[0])
|
146 |
else:
|
147 |
all_outputs = all_expert_outputs[:, i, :]
|
148 |
-
mask = (all_outputs.sum(dim=-1) != 0).
|
149 |
if mask.sum() > 0:
|
150 |
estimated = (all_outputs * mask).sum(dim=0) / mask.sum()
|
151 |
else:
|
|
|
137 |
if indices:
|
138 |
selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
|
139 |
# 只计算非零值的平均值
|
140 |
+
mask = (selected_outputs.sum(dim=-1) != 0).to(selected_outputs.dtype).unsqueeze(-1)
|
141 |
if mask.sum() > 0:
|
142 |
estimated = (selected_outputs * mask).sum(dim=0) / mask.sum()
|
143 |
else:
|
|
|
145 |
estimated = torch.zeros_like(selected_outputs[0])
|
146 |
else:
|
147 |
all_outputs = all_expert_outputs[:, i, :]
|
148 |
+
mask = (all_outputs.sum(dim=-1) != 0).to(all_outputs.dtype).unsqueeze(-1)
|
149 |
if mask.sum() > 0:
|
150 |
estimated = (all_outputs * mask).sum(dim=0) / mask.sum()
|
151 |
else:
|