autoprogrammer commited on
Commit
7385e74
·
verified ·
1 Parent(s): 39520a1

Update modeling_densebackward_olmoe0125.py

Browse files
Files changed (1) hide show
  1. modeling_densebackward_olmoe0125.py +163 -130
modeling_densebackward_olmoe0125.py CHANGED
@@ -23,140 +23,173 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
23
  router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
24
  """
25
  def forward(self, hidden_states: torch.Tensor):
26
- """
27
- 输入:
28
- hidden_states: Tensor, shape (batch_size, sequence_length, hidden_dim)
29
- 输出:
30
- final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
31
- router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
32
- 实现思路:
33
- 1. 将输入展平为 (B*seq_len, hidden_dim),通过 self.gate 得到 router_logits,
34
- 并计算全专家的 routing 权重(softmax 后)。
35
- 2. 对 routing 权重取 top-k,得到 routing_weights_topk 与 selected_experts;
36
- 如配置要求,归一化 top-k 概率。
37
- 3. 稀疏计算部分:仅计算每个 token 对于 top-k 专家的输出,
38
- 并累加得到 sparse_output(保留原版计算流程,同时记录激活专家的实际输出)。
39
- 4. Dense 估计部分:先计算所有专家对所有 token 的输出(all_expert_outputs),
40
- 再逐 token 调用 estimate_dense_output 得到 dense 输出(dense_estimated)。
41
- 5. 使用直通梯度技巧:前向输出用 sparse_output,但梯度来源于 dense_estimated。
42
- 6. 最后 reshape 为 (batch_size, sequence_length, hidden_dim) 并返回 final_output 及 router_logits.
43
- """
44
- #determine the shape of hidden_states
45
- batch_size, seq_length, hidden_dim = hidden_states.shape
46
- flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
47
-
48
- # 计算路由 logits 和全专家 routing 权重
49
- router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
50
- routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
51
-
52
- # Top-k 选择
53
- routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
54
- if self.norm_topk_prob:
55
- routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
56
- routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
57
-
58
- # ---------- 稀疏计算部分 ----------
59
- # 初始化稀疏输出,shape: (B*seq_len, hidden_dim)
60
- sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
61
- # 用于记录每个 token 对激活专家的实际输出
62
- activated_outputs = [{} for _ in range(flat_hidden.size(0))]
63
- # one-hot 编码 top-k 专家,shape: (B*seq_len, top_k, num_experts)
64
- expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
65
- expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
66
-
67
- for expert_idx in range(self.num_experts):
68
- expert_layer = self.experts[expert_idx]
69
- idx, top_x = torch.where(expert_mask[expert_idx])
70
- if top_x.numel() > 0:
71
- current_state = flat_hidden[top_x] # (n, hidden_dim)
72
- current_output = expert_layer(current_state) # (n, hidden_dim)
73
- weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
74
- weighted_output = current_output * weight
75
- sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
76
- # 保存当前 token 对该专家的实际输出
77
- for pos, token_idx in enumerate(top_x.tolist()):
78
- activated_outputs[token_idx][expert_idx] = current_output[pos]
79
- # ---------- 稀疏计算结束 ----------
80
-
81
- # ---------- Dense估计部分 ----------
82
- # 计算所有专家对所有 token 的 dense 输出,shape: (B*seq_len, num_experts, hidden_dim)
83
- # 创建全零张量,只填入已激活专家的输出
84
- all_expert_outputs = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
85
  dtype=flat_hidden.dtype, device=flat_hidden.device)
86
- # 填入已激活专家的输出
87
- for i in range(flat_hidden.size(0)):
88
- for expert_idx in activated_outputs[i].keys():
89
- all_expert_outputs[i, expert_idx] = activated_outputs[i][expert_idx]
90
- # 将 selected_experts 转换为 list,每个 token 的激活专家列表
91
- all_routing = selected_experts.tolist() # 长度为 (B*seq_len)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
 
93
- dense_outputs = []
94
- for i in range(flat_hidden.size(0)):
95
- dense_est = self.estimate_dense_output(
96
- token_idx=i,
97
- activated=all_routing[i], # 当前 token 激活的专家列表,例如 [a, b]
98
- gate_prob=routing_weights[i], # 当前 token 的完整 routing 权重 (num_experts,)
99
- activated_outputs=activated_outputs[i], # 当前 token 对激活专家的实际输出
100
- all_routing=all_routing, # 全 batch 每个 token 的激活专家列表(list of lists)
101
- all_expert_outputs=all_expert_outputs # (B*seq_len, num_experts, hidden_dim)
102
- )
103
- dense_outputs.append(dense_est.unsqueeze(0))
104
- dense_outputs = torch.cat(dense_outputs, dim=0) # (B*seq_len, hidden_dim)
105
- # ---------- Dense估计结束 ----------
106
-
107
- # 使用直通梯度:前向输出用稀疏结果,但反向传播时梯度来源于 dense 估计
108
- final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
109
- final_output = final_flat.view(batch_size, seq_length, hidden_dim)
110
- return final_output, router_logits
111
-
112
- def estimate_dense_output(self, token_idx, activated, gate_prob, activated_outputs, all_routing, all_expert_outputs):
113
- """
114
- 对于当前 token,根据 mini-batch 中的信息估计 dense 输出。
115
- 参数:
116
- token_idx: 当前 token 的索引(标量)
117
- activated: 当前 token 激活的专家列表,例如 [1, 3]
118
- gate_prob: 当前 token 的 routing 权重,形状 (num_experts,)
119
- activated_outputs: dict,当前 token 对激活专家的实际输出,形状 (hidden_dim,)
120
- all_routing: list,每个 token 的激活专家列表(长度为 N,每个元素为 list)
121
- all_expert_outputs: Tensor, (N, num_experts, hidden_dim)
122
- 返回:
123
- estimated_dense: Tensor, (hidden_dim,)
124
- """
125
- num_experts = gate_prob.size(0)
126
- dense_parts = {}
127
- # 对于激活的专家,直接使用其实际输出
128
- for idx in activated:
129
- dense_parts[idx] = activated_outputs[idx]
130
- # 对于未激活的专家,使用 mini-batch 中其他 token 的输出估计
131
- non_activated = [i for i in range(num_experts) if i not in activated]
132
- for i in non_activated:
133
- indices = []
134
- for idx, r_dec in enumerate(all_routing):
135
- if (i in r_dec) and (len(set(r_dec) & set(activated)) > 0):
136
- indices.append(idx)
137
- if indices:
138
- selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
139
- # 只计算非零值的平均值
140
- mask = (selected_outputs.sum(dim=-1) != 0).to(selected_outputs.dtype).unsqueeze(-1)
141
- if mask.sum() > 0:
142
- estimated = (selected_outputs * mask).sum(dim=0) / mask.sum()
143
- else:
144
- # 如果全是零,返回零向量
145
- estimated = torch.zeros_like(selected_outputs[0])
146
  else:
147
- all_outputs = all_expert_outputs[:, i, :]
148
- mask = (all_outputs.sum(dim=-1) != 0).to(all_outputs.dtype).unsqueeze(-1)
149
- if mask.sum() > 0:
150
- estimated = (all_outputs * mask).sum(dim=0) / mask.sum()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
  else:
152
- # 如果全是零,返回零向量
153
- estimated = torch.zeros_like(all_outputs[0])
154
- dense_parts[i] = estimated
155
- # gate_prob 加权求和各专家输出
156
- estimated_dense = 0
157
- for i in range(num_experts):
158
- estimated_dense += gate_prob[i] * dense_parts[i]
159
- return estimated_dense
 
 
 
 
 
 
 
160
 
161
 
162
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
 
23
  router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
24
  """
25
  def forward(self, hidden_states: torch.Tensor):
26
+ # determine the shape of hidden_states
27
+ batch_size, seq_length, hidden_dim = hidden_states.shape
28
+ flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
29
+ total_tokens = flat_hidden.size(0)
30
+
31
+ # 计算路由 logits 和全专家 routing 权重
32
+ router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
33
+ routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
34
+
35
+ # Top-k 选择
36
+ routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
37
+ if self.norm_topk_prob:
38
+ routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
39
+ routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
40
+
41
+ # ---------- 稀疏计算部分 ----------
42
+ # 初始化稀疏输出,shape: (B*seq_len, hidden_dim)
43
+ sparse_output = torch.zeros((total_tokens, hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
44
+
45
+ # 创建一个张量存储激活专家的输出,避免使用Python字典
46
+ # shape: (B*seq_len, num_experts, hidden_dim)
47
+ all_expert_outputs = torch.zeros((total_tokens, self.num_experts, hidden_dim),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
48
  dtype=flat_hidden.dtype, device=flat_hidden.device)
49
+
50
+ # 使用张量掩码跟踪哪些专家被激活
51
+ # shape: (B*seq_len, num_experts)
52
+ expert_activated = torch.zeros((total_tokens, self.num_experts),
53
+ dtype=torch.bool, device=flat_hidden.device)
54
+
55
+ # one-hot 编码 top-k 专家,shape: (B*seq_len, top_k, num_experts)
56
+ expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
57
+ expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
58
+
59
+ for expert_idx in range(self.num_experts):
60
+ expert_layer = self.experts[expert_idx]
61
+ idx, top_x = torch.where(expert_mask[expert_idx])
62
+ if top_x.numel() > 0:
63
+ current_state = flat_hidden[top_x] # (n, hidden_dim)
64
+ current_output = expert_layer(current_state) # (n, hidden_dim)
65
+ weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
66
+ weighted_output = current_output * weight
67
+ sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
68
+
69
+ # 保存专家输出到张量中,而不是使用字典
70
+ all_expert_outputs.index_copy_(0, top_x,
71
+ torch.zeros_like(all_expert_outputs[0:top_x.size(0)]).scatter_(
72
+ 1, expert_idx * torch.ones((top_x.size(0), 1),
73
+ dtype=torch.long,
74
+ device=flat_hidden.device),
75
+ current_output.unsqueeze(1)))
76
+
77
+ # 标记哪些专家被激活
78
+ expert_activated.index_copy_(0, top_x,
79
+ torch.zeros_like(expert_activated[0:top_x.size(0)]).scatter_(
80
+ 1, expert_idx * torch.ones((top_x.size(0), 1),
81
+ dtype=torch.long,
82
+ device=flat_hidden.device),
83
+ torch.ones((top_x.size(0), 1),
84
+ dtype=torch.bool,
85
+ device=flat_hidden.device)))
86
+ # ---------- 稀疏计算结束 ----------
87
+
88
+ # ---------- Dense估计部分 ----------
89
+ # 从GPU获取必要信息,避免过多的tensor->list转换
90
+ selected_experts_gpu = selected_experts # 保持在GPU上
91
+
92
+ # 预分配结果张量,避免在循环中append
93
+ dense_outputs = torch.zeros_like(sparse_output)
94
+
95
+ # 使用向量化的estimate_dense_output函数
96
+ dense_outputs = self.estimate_dense_output_batch(
97
+ total_tokens=total_tokens,
98
+ selected_experts=selected_experts_gpu,
99
+ routing_weights=routing_weights,
100
+ expert_activated=expert_activated,
101
+ all_expert_outputs=all_expert_outputs
102
+ )
103
+ # ---------- Dense估计结束 ----------
104
+
105
+ # 使用直通梯度:前向输出用稀疏结果,但反向传播时梯度来源于 dense 估计
106
+ final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
107
+ final_output = final_flat.view(batch_size, seq_length, hidden_dim)
108
+ return final_output, router_logits
109
+
110
+ def estimate_dense_output_batch(self, total_tokens, selected_experts, routing_weights,
111
+ expert_activated, all_expert_outputs):
112
+ """
113
+ 批量估计所有token的dense输出,优化版本。
114
+
115
+ 参数:
116
+ total_tokens: token总数
117
+ selected_experts: 每个token激活的专家索引,形状 (total_tokens, top_k)
118
+ routing_weights: 路由权重,形状 (total_tokens, num_experts)
119
+ expert_activated: 掩码张量,标记每个token激活了哪些专家,形状 (total_tokens, num_experts)
120
+ all_expert_outputs: 专家输出,形状 (total_tokens, num_experts, hidden_dim)
121
 
122
+ 返回:
123
+ dense_outputs: 形状 (total_tokens, hidden_dim)
124
+ """
125
+ hidden_dim = all_expert_outputs.size(-1)
126
+ num_experts = routing_weights.size(1)
127
+ device = all_expert_outputs.device
128
+
129
+ # 预分配结果张量
130
+ dense_outputs = torch.zeros((total_tokens, hidden_dim), dtype=all_expert_outputs.dtype, device=device)
131
+
132
+ # 对每个token单独处理(此处仍需循环,但后续可进一步优化)
133
+ for token_idx in range(total_tokens):
134
+ # 对于激活的专家,直接使用输出
135
+ activated_mask = expert_activated[token_idx] # (num_experts,)
136
+
137
+ # 对于未激活的专家,找到估计值
138
+ for expert_idx in range(num_experts):
139
+ if activated_mask[expert_idx]:
140
+ # 直接使用激活专家的输出
141
+ expert_output = all_expert_outputs[token_idx, expert_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
142
  else:
143
+ # 寻找可以用于估计的输出
144
+ # 找出其他激活了当前专家的token
145
+ tokens_with_expert = expert_activated[:, expert_idx]
146
+
147
+ # 找出同时激活了当前token的某些专家和当前专家的其他token
148
+ # 首先获取当前token激活的专家
149
+ current_activated = selected_experts[token_idx]
150
+
151
+ # 在其他token中寻找同时激活了current_activated中专家和expert_idx的token
152
+ valid_tokens = torch.zeros(total_tokens, dtype=torch.bool, device=device)
153
+
154
+ # 对于每个其他token,检查它是否同时激活了当前token的某个专家和当前专家
155
+ for other_token in range(total_tokens):
156
+ if other_token == token_idx:
157
+ continue
158
+
159
+ # 检查其他token是否激活了当前专家
160
+ if expert_activated[other_token, expert_idx]:
161
+ # 检查是否有共同激活的专家
162
+ other_experts = selected_experts[other_token]
163
+ common = torch.any(torch.isin(other_experts, current_activated))
164
+ if common:
165
+ valid_tokens[other_token] = True
166
+
167
+ # 如果找到了有效token
168
+ if valid_tokens.any():
169
+ # 获取有效token对当前专家的输出
170
+ valid_outputs = all_expert_outputs[valid_tokens, expert_idx]
171
+ # 只计算非零值的平均值
172
+ mask = (valid_outputs.sum(dim=-1) != 0).to(valid_outputs.dtype).unsqueeze(-1)
173
+ if mask.sum() > 0:
174
+ expert_output = (valid_outputs * mask).sum(dim=0) / mask.sum()
175
+ else:
176
+ expert_output = torch.zeros(hidden_dim, dtype=all_expert_outputs.dtype, device=device)
177
  else:
178
+ # 如果没有找到有效token,使用所有激活了当前专家的token的输出
179
+ if tokens_with_expert.any():
180
+ all_valid_outputs = all_expert_outputs[tokens_with_expert, expert_idx]
181
+ mask = (all_valid_outputs.sum(dim=-1) != 0).to(all_valid_outputs.dtype).unsqueeze(-1)
182
+ if mask.sum() > 0:
183
+ expert_output = (all_valid_outputs * mask).sum(dim=0) / mask.sum()
184
+ else:
185
+ expert_output = torch.zeros(hidden_dim, dtype=all_expert_outputs.dtype, device=device)
186
+ else:
187
+ expert_output = torch.zeros(hidden_dim, dtype=all_expert_outputs.dtype, device=device)
188
+
189
+ # 根据routing权重加权
190
+ dense_outputs[token_idx] += routing_weights[token_idx, expert_idx] * expert_output
191
+
192
+ return dense_outputs
193
 
194
 
195
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):