Update modeling_densebackward_olmoe0125.py
Browse files
modeling_densebackward_olmoe0125.py
CHANGED
@@ -42,7 +42,7 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
42 |
all_expert_outputs = torch.zeros((N_tokens, self.num_experts, hidden_dim),
|
43 |
dtype=flat_hidden.dtype, device=flat_hidden.device)
|
44 |
|
45 |
-
for expert_idx in
|
46 |
expert_layer = self.experts[expert_idx]
|
47 |
# 对所有token都计算当前专家的输出
|
48 |
expert_output = expert_layer(flat_hidden) # (N_tokens, hidden_dim)
|
|
|
42 |
all_expert_outputs = torch.zeros((N_tokens, self.num_experts, hidden_dim),
|
43 |
dtype=flat_hidden.dtype, device=flat_hidden.device)
|
44 |
|
45 |
+
for expert_idx in range(self.num_experts):
|
46 |
expert_layer = self.experts[expert_idx]
|
47 |
# 对所有token都计算当前专家的输出
|
48 |
expert_output = expert_layer(flat_hidden) # (N_tokens, hidden_dim)
|