autoprogrammer committed on
Commit
c401025
·
verified ·
1 Parent(s): 41d840a

Update modeling_densebackward_olmoe0125.py

Browse files
modeling_densebackward_olmoe0125.py CHANGED
@@ -30,13 +30,13 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
30
 
31
  # 计算路由逻辑
32
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
33
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
34
 
35
  # 选择top-k专家
36
  routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
37
  if self.norm_topk_prob:
38
  routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
39
- routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
40
 
41
  # ---------- 真实计算所有专家输出(密集计算)----------
42
  all_expert_outputs = torch.zeros((N_tokens, self.num_experts, hidden_dim),
 
30
 
31
  # 计算路由逻辑
32
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
33
+ routing_weights = F.softmax(router_logits, dim=1, dtype=flat_hidden.dtype) # (B*seq_len, num_experts)
34
 
35
  # 选择top-k专家
36
  routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
37
  if self.norm_topk_prob:
38
  routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
39
+ # 这里不需要转换类型,因为routing_weights已经使用了flat_hidden.dtype
40
 
41
  # ---------- 真实计算所有专家输出(密集计算)----------
42
  all_expert_outputs = torch.zeros((N_tokens, self.num_experts, hidden_dim),