autoprogrammer committed on
Commit
34dd82a
·
verified ·
1 Parent(s): f23ed8f

Update modeling_densebackward_olmoe0125.py

Browse files
Files changed (1) hide show
  1. modeling_densebackward_olmoe0125.py +56 -77
modeling_densebackward_olmoe0125.py CHANGED
@@ -10,6 +10,7 @@ from .configuration_densebackward_olmoe0125 import DenseBackwardOLMoEConfig
10
 
11
 
12
  class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
 
13
  """
14
  继承自官方 OlmoeSparseMoeBlock,实现 dense backward 功能:
15
  前向输出依旧保持与官方相同(即稀疏计算结果),
@@ -23,48 +24,27 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
23
  router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
24
  """
25
  def forward(self, hidden_states: torch.Tensor):
26
- """
27
- 输入:
28
- hidden_states: Tensor, shape (batch_size, sequence_length, hidden_dim)
29
- 输出:
30
- final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
31
- router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
32
- 实现思路:
33
- 1. 将输入展平为 (B*seq_len, hidden_dim),通过 self.gate 得到 router_logits,
34
- 并计算全专家的 routing 权重(softmax 后)。
35
- 2. 对 routing 权重取 top-k,得到 routing_weights_topk 与 selected_experts;
36
- 如配置要求,归一化 top-k 概率。
37
- 3. 稀疏计算部分:仅计算每个 token 对于 top-k 专家的输出,
38
- 并累加得到 sparse_output(保留原版计算流程,同时记录激活专家的实际输出)。
39
- 4. Dense 估计部分:先计算所有专家对所有 token 的输出(all_expert_outputs),
40
- 再逐 token 调用 estimate_dense_output 得到 dense 输出(dense_estimated)。
41
- 5. 使用直通梯度技巧:前向输出用 sparse_output,但梯度来源于 dense_estimated。
42
- 6. 最后 reshape 为 (batch_size, sequence_length, hidden_dim) 并返回 final_output 及 router_logits.
43
- """
44
- #determine the shape of hidden_states
45
  batch_size, seq_length, hidden_dim = hidden_states.shape
46
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
47
 
48
- # 计算路由 logits 和全专家 routing 权重
49
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
50
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
51
 
52
- # Top-k 选择
53
  routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
54
  if self.norm_topk_prob:
55
  routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
56
  routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
57
 
58
  # ---------- 稀疏计算部分 ----------
59
- # 初始化稀疏输出,shape: (B*seq_len, hidden_dim)
60
- sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
61
- # 用于记录每个 token 对激活专家的实际输出
62
- activated_outputs = [{} for _ in range(flat_hidden.size(0))]
63
- # one-hot 编码 top-k 专家,shape: (B*seq_len, top_k, num_experts)
64
- expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
65
  expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
66
 
67
- for expert_idx in range(self.num_experts):
68
  expert_layer = self.experts[expert_idx]
69
  idx, top_x = torch.where(expert_mask[expert_idx])
70
  if top_x.numel() > 0:
@@ -73,75 +53,77 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
73
  weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
74
  weighted_output = current_output * weight
75
  sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
76
- # 保存当前 token 对该专家的实际输出
77
- for pos, token_idx in enumerate(top_x.tolist()):
78
- activated_outputs[token_idx][expert_idx] = current_output[pos]
79
  # ---------- 稀疏计算结束 ----------
80
 
81
- # ---------- Dense估计部分 ----------
82
- # 计算所有专家对所有 token 的 dense 输出,shape: (B*seq_len, num_experts, hidden_dim)
83
- # 创建全零张量,只填入已激活专家的输出
84
- all_expert_outputs = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
85
- dtype=flat_hidden.dtype, device=flat_hidden.device)
86
- # 填入已激活专家的输出
87
- for i in range(flat_hidden.size(0)):
88
- for expert_idx in activated_outputs[i].keys():
89
- all_expert_outputs[i, expert_idx] = activated_outputs[i][expert_idx]
90
- # 将 selected_experts 转换为 list,每个 token 的激活专家列表
91
- all_routing = selected_experts.tolist() # 长度为 (B*seq_len)
92
 
93
- dense_outputs = []
94
- for i in range(flat_hidden.size(0)):
95
- dense_est = self.estimate_dense_output(
96
- token_idx=i,
97
- activated=all_routing[i], # 当前 token 激活的专家列表,例如 [a, b]
98
- gate_prob=routing_weights[i], # 当前 token 的完整 routing 权重 (num_experts,)
99
- activated_outputs=activated_outputs[i], # 当前 token 对激活专家的实际输出
100
- all_routing=all_routing, # batch 每个 token 的激活专家列表(list of lists)
101
- all_expert_outputs=all_expert_outputs # (B*seq_len, num_experts, hidden_dim)
102
- )
103
- dense_outputs.append(dense_est.unsqueeze(0))
104
- dense_outputs = torch.cat(dense_outputs, dim=0) # (B*seq_len, hidden_dim)
105
- # ---------- Dense估计结束 ----------
106
-
107
- # 使用直通梯度:前向输出用稀疏结果,但反向传播时梯度来源于 dense 估计
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
108
  final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
109
  final_output = final_flat.view(batch_size, seq_length, hidden_dim)
110
  return final_output, router_logits
111
 
112
  def estimate_dense_output(self, token_idx, activated, gate_prob, activated_outputs, all_routing, all_expert_outputs):
113
- """
114
- 对于当前 token,根据 mini-batch 中的信息估计 dense 输出。
115
- 参数:
116
- token_idx: 当前 token 的索引(标量)
117
- activated: 当前 token 激活的专家列表,例如 [1, 3]
118
- gate_prob: 当前 token 的 routing 权重,形状 (num_experts,)
119
- activated_outputs: dict,当前 token 对激活专家的实际输出,形状 (hidden_dim,)
120
- all_routing: list,每个 token 的激活专家列表(长度为 N,每个元素为 list)
121
- all_expert_outputs: Tensor, (N, num_experts, hidden_dim)
122
- 返回:
123
- estimated_dense: Tensor, (hidden_dim,)
124
- """
125
  num_experts = gate_prob.size(0)
126
  dense_parts = {}
127
- # 对于激活的专家,直接使用其实际输出
128
  for idx in activated:
129
  dense_parts[idx] = activated_outputs[idx]
130
- # 对于未激活的专家,使用 mini-batch 中其他 token 的输出估计
131
  non_activated = [i for i in range(num_experts) if i not in activated]
132
- for i in non_activated:
133
  indices = []
134
  for idx, r_dec in enumerate(all_routing):
135
  if (i in r_dec) and (len(set(r_dec) & set(activated)) > 0):
136
  indices.append(idx)
137
  if indices:
138
  selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
139
- # 只计算非零值的平均值
140
  mask = (selected_outputs.sum(dim=-1) != 0).to(selected_outputs.dtype).unsqueeze(-1)
141
  if mask.sum() > 0:
142
  estimated = (selected_outputs * mask).sum(dim=0) / mask.sum()
143
  else:
144
- # 如果全是零,返回零向量
145
  estimated = torch.zeros_like(selected_outputs[0])
146
  else:
147
  all_outputs = all_expert_outputs[:, i, :]
@@ -149,16 +131,13 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
149
  if mask.sum() > 0:
150
  estimated = (all_outputs * mask).sum(dim=0) / mask.sum()
151
  else:
152
- # 如果全是零,返回零向量
153
  estimated = torch.zeros_like(all_outputs[0])
154
  dense_parts[i] = estimated
155
- # 按 gate_prob 加权求和各专家输出
156
  estimated_dense = 0
157
  for i in range(num_experts):
158
  estimated_dense += gate_prob[i] * dense_parts[i]
159
  return estimated_dense
160
 
161
-
162
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
163
  """
164
  自定义的 Olmoe ForCausalLM 模型,使用新的 DenseBackwardOlmoeSparseMoeBlock 替换原版的 MoE 模块,
 
10
 
11
 
12
  class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
13
+
14
  """
15
  继承自官方 OlmoeSparseMoeBlock,实现 dense backward 功能:
16
  前向输出依旧保持与官方相同(即稀疏计算结果),
 
24
  router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
25
  """
26
  def forward(self, hidden_states: torch.Tensor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
27
  batch_size, seq_length, hidden_dim = hidden_states.shape
28
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
29
 
 
30
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
31
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
32
 
 
33
  routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
34
  if self.norm_topk_prob:
35
  routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
36
  routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
37
 
38
  # ---------- 稀疏计算部分 ----------
39
+ sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim),
40
+ dtype=flat_hidden.dtype, device=flat_hidden.device)
41
+ # 使用 tensor 存储,每个 token 对各专家的输出:形状 (B*seq_len, num_experts, hidden_dim)
42
+ activated_outputs_tensor = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
43
+ dtype=flat_hidden.dtype, device=flat_hidden.device)
44
+ expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
45
  expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
46
 
47
+ for expert_idx in tqdm(range(self.num_experts), desc="修改版本-专家循环"):
48
  expert_layer = self.experts[expert_idx]
49
  idx, top_x = torch.where(expert_mask[expert_idx])
50
  if top_x.numel() > 0:
 
53
  weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
54
  weighted_output = current_output * weight
55
  sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
56
+ # 直接存入 tensor:激活 token 对当前专家的输出
57
+ activated_outputs_tensor[top_x, expert_idx, :] = current_output
 
58
  # ---------- 稀疏计算结束 ----------
59
 
60
+ # ---------- Dense估计部分 (向量化版本,激活专家直接使用输出) ----------
61
+ all_expert_outputs = activated_outputs_tensor # (B*seq_len, num_experts, hidden_dim)
62
+ all_routing = selected_experts.tolist() # list,每个 token 的激活专家列表
 
 
 
 
 
 
 
 
63
 
64
+ N_tokens = flat_hidden.size(0)
65
+ num_experts = self.num_experts
66
+
67
+ # 将 selected_experts 转换为 one-hot 二值矩阵 R: (N_tokens, num_experts)
68
+ R = F.one_hot(selected_experts, num_classes=num_experts).float() # (N_tokens, top_k, num_experts)
69
+ R = R.sum(dim=1) # (N_tokens, num_experts),激活的专家位置值大于0
70
+
71
+ # 计算 token 之间共享激活情况 S: (N_tokens, N_tokens)
72
+ S = torch.matmul(R, R.t()) # S[i,j] > 0 表示 token i 和 token j 至少共享一个激活专家
73
+ S = S * (1 - torch.eye(N_tokens, device=S.device)) # 去除自身
74
+
75
+ # 构造候选 mask M: (N_tokens, N_tokens, num_experts)
76
+ # M[i, j, e] = 1 表示 token j 激活了专家 e 且 token i 与 token j 至少共享一个激活专家
77
+ R_expanded = R.unsqueeze(0).expand(N_tokens, -1, -1) # (N_tokens, N_tokens, num_experts)
78
+ S_expanded = S.unsqueeze(-1) # (N_tokens, N_tokens, 1)
79
+ candidate_mask = ((R_expanded > 0) & (S_expanded > 0)).float() # (N_tokens, N_tokens, num_experts)
80
+
81
+ # 对于数值稳定,排除 token 自身(对角线置0)
82
+ candidate_mask = candidate_mask * (1 - torch.eye(N_tokens, device=candidate_mask.device).unsqueeze(-1))
83
+
84
+ # 扩展 mask 和 all_expert_outputs 以便批量聚合
85
+ # all_expert_outputs: (N_tokens, num_experts, hidden_dim)
86
+ candidate_mask_exp = candidate_mask.unsqueeze(-1) # (N_tokens, N_tokens, num_experts, 1)
87
+ all_expert_outputs_exp = all_expert_outputs.unsqueeze(0) # (1, N_tokens, num_experts, hidden_dim)
88
+
89
+ # 对每个 token i 和专家 e,聚合候选 token 的输出
90
+ sum_outputs = (candidate_mask_exp * all_expert_outputs_exp).sum(dim=1) # (N_tokens, num_experts, hidden_dim)
91
+ count_outputs = candidate_mask.sum(dim=1).unsqueeze(-1) # (N_tokens, num_experts, 1)
92
+ estimated_dense_all = torch.where(count_outputs > 0, sum_outputs / count_outputs,
93
+ torch.zeros_like(sum_outputs)) # (N_tokens, num_experts, hidden_dim)
94
+
95
+ # 对于激活的专家,直接使用当前 token 的输出
96
+ # R > 0 表示激活,扩展为 (N_tokens, num_experts, 1) 与 activated_outputs_tensor 对齐
97
+ activated_mask = (R > 0).unsqueeze(-1)
98
+ estimated_dense_all = torch.where(activated_mask, activated_outputs_tensor, estimated_dense_all)
99
+
100
+ # 利用 gate_prob 加权聚合所有专家输出
101
+ gate_prob_exp = routing_weights.to(estimated_dense_all.dtype).unsqueeze(-1) # (N_tokens, num_experts, 1)
102
+ dense_outputs = (gate_prob_exp * estimated_dense_all).sum(dim=1) # (N_tokens, hidden_dim)
103
+ # ---------- Dense估计结束 (向量化版本) ----------
104
+
105
  final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
106
  final_output = final_flat.view(batch_size, seq_length, hidden_dim)
107
  return final_output, router_logits
108
 
109
  def estimate_dense_output(self, token_idx, activated, gate_prob, activated_outputs, all_routing, all_expert_outputs):
 
 
 
 
 
 
 
 
 
 
 
 
110
  num_experts = gate_prob.size(0)
111
  dense_parts = {}
112
+ # 对于激活的专家,直接使用 tensor 的对应行
113
  for idx in activated:
114
  dense_parts[idx] = activated_outputs[idx]
 
115
  non_activated = [i for i in range(num_experts) if i not in activated]
116
+ for i in tqdm(non_activated, desc=f"修改版本-Token {token_idx} 非激活专家估计"):
117
  indices = []
118
  for idx, r_dec in enumerate(all_routing):
119
  if (i in r_dec) and (len(set(r_dec) & set(activated)) > 0):
120
  indices.append(idx)
121
  if indices:
122
  selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
 
123
  mask = (selected_outputs.sum(dim=-1) != 0).to(selected_outputs.dtype).unsqueeze(-1)
124
  if mask.sum() > 0:
125
  estimated = (selected_outputs * mask).sum(dim=0) / mask.sum()
126
  else:
 
127
  estimated = torch.zeros_like(selected_outputs[0])
128
  else:
129
  all_outputs = all_expert_outputs[:, i, :]
 
131
  if mask.sum() > 0:
132
  estimated = (all_outputs * mask).sum(dim=0) / mask.sum()
133
  else:
 
134
  estimated = torch.zeros_like(all_outputs[0])
135
  dense_parts[i] = estimated
 
136
  estimated_dense = 0
137
  for i in range(num_experts):
138
  estimated_dense += gate_prob[i] * dense_parts[i]
139
  return estimated_dense
140
 
 
141
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
142
  """
143
  自定义的 Olmoe ForCausalLM 模型,使用新的 DenseBackwardOlmoeSparseMoeBlock 替换原版的 MoE 模块,