autoprogrammer commited on
Commit
41d840a
·
verified ·
1 Parent(s): 278a1f9

Update modeling_densebackward_olmoe0125.py

Browse files
Files changed (1) hide show
  1. modeling_densebackward_olmoe0125.py +32 -97
modeling_densebackward_olmoe0125.py CHANGED
@@ -26,118 +26,53 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
26
  def forward(self, hidden_states: torch.Tensor):
27
  batch_size, seq_length, hidden_dim = hidden_states.shape
28
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
 
29
 
 
30
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
31
- routing_weights = F.softmax(router_logits, dim=1, dtype=flat_hidden.dtype) # (B*seq_len, num_experts)
32
 
 
33
  routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
34
  if self.norm_topk_prob:
35
  routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
36
  routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
37
 
38
- # ---------- 稀疏计算部分 ----------
39
- sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim),
40
- dtype=flat_hidden.dtype, device=flat_hidden.device)
41
- # 使用 tensor 存储,每个 token 对各专家的输出:形状 (B*seq_len, num_experts, hidden_dim)
42
- activated_outputs_tensor = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
43
- dtype=flat_hidden.dtype, device=flat_hidden.device)
44
- expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
45
- expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
46
-
47
  for expert_idx in range(self.num_experts):
48
  expert_layer = self.experts[expert_idx]
49
- idx, top_x = torch.where(expert_mask[expert_idx])
50
- if top_x.numel() > 0:
51
- current_state = flat_hidden[top_x] # (n, hidden_dim)
52
- current_output = expert_layer(current_state) # (n, hidden_dim)
53
- weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
54
- weighted_output = current_output * weight
55
- sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
56
- # 直接存入 tensor:激活 token 对当前专家的输出
57
- activated_outputs_tensor[top_x, expert_idx, :] = current_output
58
- # ---------- 稀疏计算结束 ----------
59
-
60
- # ---------- Dense估计部分 (向量化版本,激活专家直接使用输出) ----------
61
- all_expert_outputs = activated_outputs_tensor # (B*seq_len, num_experts, hidden_dim)
62
- all_routing = selected_experts.tolist() # list,每个 token 的激活专家列表
63
 
64
- N_tokens = flat_hidden.size(0)
65
- num_experts = self.num_experts
66
-
67
- # 将 selected_experts 转换为 one-hot 二值矩阵 R: (N_tokens, num_experts)
68
- R = F.one_hot(selected_experts, num_classes=num_experts).to(flat_hidden.dtype) # (N_tokens, top_k, num_experts)
69
- R = R.sum(dim=1) # (N_tokens, num_experts),激活的专家位置值大于0
70
-
71
- # 计算 token 之间共享激活情况 S: (N_tokens, N_tokens)
72
- S = torch.matmul(R, R.t()) # S[i,j] > 0 表示 token i 和 token j 至少共享一个激活专家
73
- S = S * (1 - torch.eye(N_tokens, device=S.device, dtype=flat_hidden.dtype)) # 去除自身
74
-
75
- # 构造候选 mask M: (N_tokens, N_tokens, num_experts)
76
- # M[i, j, e] = 1 表示 token j 激活了专家 e 且 token i 与 token j 至少共享一个激活专家
77
- R_expanded = R.unsqueeze(0).expand(N_tokens, -1, -1) # (N_tokens, N_tokens, num_experts)
78
- S_expanded = S.unsqueeze(-1) # (N_tokens, N_tokens, 1)
79
- candidate_mask = ((R_expanded > 0) & (S_expanded > 0)).to(flat_hidden.dtype) # (N_tokens, N_tokens, num_experts)
80
-
81
- # 对于数值稳定,排除 token 自身(对角线置0)
82
- candidate_mask = candidate_mask * (1 - torch.eye(N_tokens, device=candidate_mask.device, dtype=flat_hidden.dtype).unsqueeze(-1))
83
-
84
- # 扩展 mask 和 all_expert_outputs 以便批量聚合
85
- # all_expert_outputs: (N_tokens, num_experts, hidden_dim)
86
- candidate_mask_exp = candidate_mask.unsqueeze(-1) # (N_tokens, N_tokens, num_experts, 1)
87
- all_expert_outputs_exp = all_expert_outputs.unsqueeze(0) # (1, N_tokens, num_experts, hidden_dim)
88
-
89
- # 对每个 token i 和专家 e,聚合候选 token 的输出
90
- sum_outputs = (candidate_mask_exp * all_expert_outputs_exp).sum(dim=1) # (N_tokens, num_experts, hidden_dim)
91
- count_outputs = candidate_mask.sum(dim=1).unsqueeze(-1) # (N_tokens, num_experts, 1)
92
- estimated_dense_all = torch.where(count_outputs > 0, sum_outputs / (count_outputs+1),
93
- torch.zeros_like(sum_outputs)) # (N_tokens, num_experts, hidden_dim)
94
-
95
- # 对于激活的专家,直接使用当前 token 的输出
96
- # R > 0 表示激活,扩展为 (N_tokens, num_experts, 1) 与 activated_outputs_tensor 对齐
97
- activated_mask = (R > 0).unsqueeze(-1)
98
- estimated_dense_all = torch.where(activated_mask, activated_outputs_tensor, estimated_dense_all)
99
-
100
- # 利用 gate_prob 加权聚合所有专家输出
101
- gate_prob_exp = routing_weights.to(estimated_dense_all.dtype).unsqueeze(-1) # (N_tokens, num_experts, 1)
102
- dense_outputs = (gate_prob_exp * estimated_dense_all).sum(dim=1) # (N_tokens, hidden_dim)
103
- # ---------- Dense估计结束 (向量化版本) ----------
104
-
105
  final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
106
  final_output = final_flat.view(batch_size, seq_length, hidden_dim)
 
107
  return final_output, router_logits
108
 
109
- def estimate_dense_output(self, token_idx, activated, gate_prob, activated_outputs, all_routing, all_expert_outputs):
110
- num_experts = gate_prob.size(0)
111
- dense_parts = {}
112
- # 对于激活的专家,直接使用 tensor 的对应行
113
- for idx in activated:
114
- dense_parts[idx] = activated_outputs[idx]
115
- non_activated = [i for i in range(num_experts) if i not in activated]
116
- for i in non_activated:
117
- indices = []
118
- for idx, r_dec in enumerate(all_routing):
119
- if (i in r_dec) and (len(set(r_dec) & set(activated)) > 0):
120
- indices.append(idx)
121
- if indices:
122
- selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
123
- mask = (selected_outputs.sum(dim=-1) != 0).to(selected_outputs.dtype).unsqueeze(-1)
124
- if mask.sum() > 0:
125
- estimated = (selected_outputs * mask).sum(dim=0) / mask.sum()
126
- else:
127
- estimated = torch.zeros_like(selected_outputs[0])
128
- else:
129
- all_outputs = all_expert_outputs[:, i, :]
130
- mask = (all_outputs.sum(dim=-1) != 0).to(all_outputs.dtype).unsqueeze(-1)
131
- if mask.sum() > 0:
132
- estimated = (all_outputs * mask).sum(dim=0) / mask.sum()
133
- else:
134
- estimated = torch.zeros_like(all_outputs[0])
135
- dense_parts[i] = estimated
136
- estimated_dense = 0
137
- for i in range(num_experts):
138
- estimated_dense += gate_prob[i] * dense_parts[i]
139
- return estimated_dense
140
-
141
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
142
  """
143
  自定义的 Olmoe ForCausalLM 模型,使用新的 DenseBackwardOlmoeSparseMoeBlock 替换原版的 MoE 模块,
 
26
  def forward(self, hidden_states: torch.Tensor):
27
  batch_size, seq_length, hidden_dim = hidden_states.shape
28
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
29
+ N_tokens = flat_hidden.size(0)
30
 
31
+ # 计算路由逻辑
32
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
33
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
34
 
35
+ # 选择top-k专家
36
  routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
37
  if self.norm_topk_prob:
38
  routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
39
  routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
40
 
41
+ # ---------- 真实计算所有专家输出(密集计算)----------
42
+ all_expert_outputs = torch.zeros((N_tokens, self.num_experts, hidden_dim),
43
+ dtype=flat_hidden.dtype, device=flat_hidden.device)
44
+
 
 
 
 
 
45
  for expert_idx in range(self.num_experts):
46
  expert_layer = self.experts[expert_idx]
47
+ # 对所有token都计算当前专家的输出
48
+ expert_output = expert_layer(flat_hidden) # (N_tokens, hidden_dim)
49
+ all_expert_outputs[:, expert_idx, :] = expert_output
 
 
 
 
 
 
 
 
 
 
 
50
 
51
+ # ---------- 提取激活专家输出(稀疏前向)----------
52
+ # 计算稀疏输出
53
+ sparse_output = torch.zeros((N_tokens, hidden_dim),
54
+ dtype=flat_hidden.dtype, device=flat_hidden.device)
55
+
56
+ # 为每个token,提取并加权其激活专家的输出
57
+ for token_idx in range(N_tokens):
58
+ for k in range(self.top_k):
59
+ expert_idx = selected_experts[token_idx, k].item()
60
+ weight = routing_weights_topk[token_idx, k]
61
+ sparse_output[token_idx] += all_expert_outputs[token_idx, expert_idx] * weight
62
+
63
+ # ---------- 密集计算聚合(用于反向传播)----------
64
+ # 使用所有专家的输出和路由权重计算密集输出
65
+ routing_weights_expanded = routing_weights.unsqueeze(-1) # (N_tokens, num_experts, 1)
66
+ dense_outputs = (all_expert_outputs * routing_weights_expanded).sum(dim=1) # (N_tokens, hidden_dim)
67
+
68
+ # ---------- 组合稀疏前向和密集反向 ----------
69
+ # sparse_output.detach()保留稀疏前向计算图
70
+ # (dense_outputs - dense_outputs.detach())只保留密集反向梯度
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
71
  final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
72
  final_output = final_flat.view(batch_size, seq_length, hidden_dim)
73
+
74
  return final_output, router_logits
75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
76
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
77
  """
78
  自定义的 Olmoe ForCausalLM 模型,使用新的 DenseBackwardOlmoeSparseMoeBlock 替换原版的 MoE 模块,