Update modelling_deepseek.py
modelling_deepseek.py  +27 -4
@@ -265,7 +265,30 @@ class MoEGate(nn.Module):
             topk_weight = topk_weight / denominator
 
         # Expert-level computation auxiliary loss
-
+        # (was absent before)
+        if self.training and self.alpha > 0.0:
+            scores_for_aux = scores
+            aux_topk = self.top_k
+            # always compute aux loss based on the naive greedy topk method
+            topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
+            if self.seq_aux:
+                scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
+                ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device, dtype=torch.float32)
+                ce.scatter_add_(
+                    1,
+                    topk_idx_for_aux_loss,
+                    torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device, dtype=torch.float32)
+                )
+                ce.div_(seq_len * aux_topk / self.n_routed_experts)
+                aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
+            else:
+                mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
+                ce = mask_ce.float().mean(0)
+                Pi = scores_for_aux.mean(0)
+                fi = ce * self.n_routed_experts
+                aux_loss = (Pi * fi).sum() * self.alpha
+        else:
+            aux_loss = None
         return topk_idx, topk_weight.to(hidden_states.dtype), aux_loss
 
 
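Note on the hunk above: this restores the expert-level load-balancing term in MoEGate. With seq_aux, ce counts how many of each sequence's top-k slots landed on each expert (scatter_add_ over the routed indices), is normalized so perfectly even routing gives 1.0, and is weighted by the per-sequence mean router scores; without seq_aux, the loss is the batch-level product of expert usage frequency (fi) and mean router probability (Pi), scaled by alpha. A minimal toy check of the batch-level branch; n_experts, top_k, alpha, and the tensor sizes below are illustrative, not values from this repo:

import torch
import torch.nn.functional as F

n_experts, top_k, alpha = 4, 2, 0.001                 # toy values, not from the config
scores = torch.rand(6, n_experts).softmax(dim=-1)     # router probabilities, [tokens, experts]
topk_idx = scores.topk(top_k, dim=-1).indices         # selected experts, [tokens, top_k]

mask_ce = F.one_hot(topk_idx.view(-1), num_classes=n_experts)
ce = mask_ce.float().mean(0)        # fraction of routed slots handled by each expert
Pi = scores.mean(0)                 # mean router probability per expert
fi = ce * n_experts                 # rescaled so a perfectly even split gives 1.0
aux_loss = (Pi * fi).sum() * alpha
print(aux_loss)                     # scalar; grows when a few experts absorb most tokens

Either branch penalizes the router for concentrating probability mass on a small subset of experts.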
@@ -314,11 +337,11 @@ class DeepseekMoE(nn.Module):
         hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
         flat_topk_idx = topk_idx.view(-1)
         if self.training:
-            y = self.moe_train(hidden_states, flat_topk_idx, topk_weight.view(-1, 1))
+            y = self.moe_train(hidden_states, flat_topk_idx, topk_weight)  # removed unnecessary .view(-1, 1)
             y = y.view(*orig_shape)
             y = AddAuxiliaryLoss.apply(y, aux_loss)
         else:
-            y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
+            y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight).view(*orig_shape)  # removed unnecessary .view(-1, 1)
         if self.config.n_shared_experts is not None:
             y = y + self.shared_experts(identity)
         return y
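Note on the hunk above: the gate hands back topk_weight with shape [num_tokens, top_k], and that is exactly the layout the weighted combine in moe_train (next hunk) consumes via y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1), so flattening it to [num_tokens * top_k, 1] before the call was unnecessary. A rough shape walk-through with assumed toy sizes:

import torch

num_tokens, top_k, hidden = 6, 2, 8                   # toy sizes, purely illustrative
topk_weight = torch.rand(num_tokens, top_k)           # gate output layout: [num_tokens, top_k]
y = torch.randn(num_tokens * top_k, hidden)           # one expert output per (token, slot) pair

# the combine step from moe_train, as shown in the following hunk
combined = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
print(combined.shape)                                 # torch.Size([6, 8]): one row per token

# with the old topk_weight.view(-1, 1) (shape [12, 1]), the same expression would view y
# as [12, 1, hidden] and sum over a size-1 dim, yielding [12, hidden] instead of [6, hidden]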
@@ -329,7 +352,7 @@ class DeepseekMoE(nn.Module):
         for i, expert in enumerate(self.experts):
             y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
         y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
-        return y
+        return y.to(hidden_states.dtype)  # .sum() in previous line returns fp32 tensor
 
     @torch.no_grad()
     def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
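Note on the hunk above: if topk_weight is still float32 at this point (the gate computes its scores with a float32 softmax), PyTorch type promotion makes the product and the .sum(dim=1) float32 even when the expert outputs are bf16/fp16, which is what the new cast undoes. A small illustration with assumed toy sizes:

import torch

y = torch.randn(6, 2, 8).to(torch.bfloat16)   # per-expert outputs in half precision
w = torch.rand(6, 2, dtype=torch.float32)     # routing weights kept in float32
out = (y * w.unsqueeze(-1)).sum(dim=1)
print(out.dtype)                # torch.float32: promotion follows the fp32 weights
print(out.to(y.dtype).dtype)    # torch.bfloat16 after the explicit cast, as in the new return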
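General note: the new aux_loss only reaches the optimizer because the training path attaches it to the expert output with AddAuxiliaryLoss.apply(y, aux_loss) (second hunk). That helper's definition is not part of this diff; a minimal sketch of the pattern such a function usually follows (implementation details assumed here, not copied from this file) is an identity on y in the forward pass with a unit gradient injected into the auxiliary loss on the way back:

import torch

class AddAuxiliaryLoss(torch.autograd.Function):
    # sketch: forward returns x unchanged; backward feeds a gradient of 1.0 into the
    # auxiliary loss so the balancing term influences the gate's parameters
    @staticmethod
    def forward(ctx, x, loss):
        assert loss.numel() == 1
        ctx.dtype = loss.dtype
        ctx.required_aux_loss = loss.requires_grad
        return x

    @staticmethod
    def backward(ctx, grad_output):
        grad_loss = None
        if ctx.required_aux_loss:
            grad_loss = torch.ones(1, dtype=ctx.dtype, device=grad_output.device)
        return grad_output, grad_loss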