vltnmmdv committed on
Commit
7c4e599
·
verified ·
1 Parent(s): a32fc22

Update modelling_deepseek.py

Browse files
Files changed (1) hide show
  1. modelling_deepseek.py +27 -4
modelling_deepseek.py CHANGED
@@ -265,7 +265,30 @@ class MoEGate(nn.Module):
265
  topk_weight = topk_weight / denominator
266
 
267
  # Expert-level computation auxiliary loss
268
- aux_loss = None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
269
  return topk_idx, topk_weight.to(hidden_states.dtype), aux_loss
270
 
271
 
@@ -314,11 +337,11 @@ class DeepseekMoE(nn.Module):
314
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
315
  flat_topk_idx = topk_idx.view(-1)
316
  if self.training:
317
- y = self.moe_train(hidden_states, flat_topk_idx, topk_weight.view(-1, 1))
318
  y = y.view(*orig_shape)
319
  y = AddAuxiliaryLoss.apply(y, aux_loss)
320
  else:
321
- y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight.view(-1, 1)).view(*orig_shape)
322
  if self.config.n_shared_experts is not None:
323
  y = y + self.shared_experts(identity)
324
  return y
@@ -329,7 +352,7 @@ class DeepseekMoE(nn.Module):
329
  for i, expert in enumerate(self.experts):
330
  y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
331
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
332
- return y
333
 
334
  @torch.no_grad()
335
  def moe_infer(self, x, flat_expert_indices, flat_expert_weights):
 
265
  topk_weight = topk_weight / denominator
266
 
267
  # Expert-level computation auxiliary loss
268
+ # (was absent before)
269
+ if self.training and self.alpha > 0.0:
270
+ scores_for_aux = scores
271
+ aux_topk = self.top_k
272
+ # always compute aux loss based on the naive greedy topk method
273
+ topk_idx_for_aux_loss = topk_idx.view(bsz, -1)
274
+ if self.seq_aux:
275
+ scores_for_seq_aux = scores_for_aux.view(bsz, seq_len, -1)
276
+ ce = torch.zeros(bsz, self.n_routed_experts, device=hidden_states.device, dtype=torch.float32)
277
+ ce.scatter_add_(
278
+ 1,
279
+ topk_idx_for_aux_loss,
280
+ torch.ones(bsz, seq_len * aux_topk, device=hidden_states.device, dtype=torch.float32)
281
+ )
282
+ ce.div_(seq_len * aux_topk / self.n_routed_experts)
283
+ aux_loss = (ce * scores_for_seq_aux.mean(dim=1)).sum(dim=1).mean() * self.alpha
284
+ else:
285
+ mask_ce = F.one_hot(topk_idx_for_aux_loss.view(-1), num_classes=self.n_routed_experts)
286
+ ce = mask_ce.float().mean(0)
287
+ Pi = scores_for_aux.mean(0)
288
+ fi = ce * self.n_routed_experts
289
+ aux_loss = (Pi * fi).sum() * self.alpha
290
+ else:
291
+ aux_loss = None
292
  return topk_idx, topk_weight.to(hidden_states.dtype), aux_loss
293
 
294
 
 
337
  hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
338
  flat_topk_idx = topk_idx.view(-1)
339
  if self.training:
340
+ y = self.moe_train(hidden_states, flat_topk_idx, topk_weight) # removed unnecessary .view(-1, 1)
341
  y = y.view(*orig_shape)
342
  y = AddAuxiliaryLoss.apply(y, aux_loss)
343
  else:
344
+ y = self.moe_infer(hidden_states, flat_topk_idx, topk_weight).view(*orig_shape) # removed unnecessary .view(-1, 1)
345
  if self.config.n_shared_experts is not None:
346
  y = y + self.shared_experts(identity)
347
  return y
 
352
  for i, expert in enumerate(self.experts):
353
  y[flat_topk_idx == i] = expert(hidden_states[flat_topk_idx == i])
354
  y = (y.view(*topk_weight.shape, -1) * topk_weight.unsqueeze(-1)).sum(dim=1)
355
+ return y.to(hidden_states.dtype) # .sum() in previous line returns fp32 tensor
356
 
357
  @torch.no_grad()
358
  def moe_infer(self, x, flat_expert_indices, flat_expert_weights):