autoprogrammer committed on
Commit
9fdbe51
·
verified ·
1 Parent(s): 9d56eed

Update modeling_densebackward_olmoe0125.py

Browse files
Files changed (1) hide show
  1. modeling_densebackward_olmoe0125.py +105 -141
modeling_densebackward_olmoe0125.py CHANGED
@@ -6,15 +6,46 @@ import torch.nn.functional as F
6
 
7
  # 导入官方实现(注意根据你的 transformers 版本调整导入路径)
8
  from transformers.models.olmoe.modeling_olmoe import OlmoeForCausalLM, OlmoeSparseMoeBlock, OlmoeMLP
9
- from .configuration_densebackward_olmoe0125 import DenseBackwardOLMoEConfig
10
 
11
 
12
  class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
 
 
 
 
 
 
 
 
 
 
 
 
13
  def forward(self, hidden_states: torch.Tensor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  batch_size, seq_length, hidden_dim = hidden_states.shape
15
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
16
 
17
- # 计算路由 logits 和路由权重
18
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
19
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
20
 
@@ -25,130 +56,90 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
25
  routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
26
 
27
  # ---------- 稀疏计算部分 ----------
28
- # 初始化稀疏输出
29
  sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
30
-
31
- # 记录每个专家被激活的token
32
- token_indices_per_expert = [[] for _ in range(self.num_experts)]
33
- expert_outputs_dict = {i: {} for i in range(self.num_experts)}
34
-
35
- # one-hot 编码 top-k 专家
36
  expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
37
  expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
38
 
39
- # 保存每个token激活的专家列表
40
- token_to_experts = [[] for _ in range(flat_hidden.size(0))]
41
-
42
- # 稀疏计算
43
  for expert_idx in range(self.num_experts):
44
  expert_layer = self.experts[expert_idx]
45
  idx, top_x = torch.where(expert_mask[expert_idx])
46
-
47
  if top_x.numel() > 0:
48
- # 记录该专家被哪些token激活
49
- top_x_list = top_x.tolist()
50
- token_indices_per_expert[expert_idx].extend(top_x_list)
51
-
52
- # 记录每个token激活了哪些专家
53
- for token_idx in top_x_list:
54
- token_to_experts[token_idx].append(expert_idx)
55
-
56
- # 标准MoE前向计算
57
  current_state = flat_hidden[top_x] # (n, hidden_dim)
58
  current_output = expert_layer(current_state) # (n, hidden_dim)
59
  weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
60
  weighted_output = current_output * weight
61
  sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
62
-
63
- # 保存未加权的专家输出,用于后续估计
64
- for i, token_idx in enumerate(top_x_list):
65
- expert_outputs_dict[expert_idx][token_idx] = current_output[i]
66
-
67
- # ---------- 初始化密集估计输出 ----------
68
- # 创建一个每个token对所有专家输出的张量
69
- dense_outputs = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
70
- dtype=flat_hidden.dtype, device=flat_hidden.device)
71
-
72
- # 首先将已计算的专家输出填入
73
- for expert_idx in range(self.num_experts):
74
- for token_idx, output in expert_outputs_dict[expert_idx].items():
75
- dense_outputs[token_idx, expert_idx] = output
76
-
77
- # ---------- 添加伪梯度路径 ----------
78
- # 对于每个未激活的专家,创建一个小的梯度路径
79
- # 首先创建一个全零的mask标记哪些(token,expert)对已被计算
80
- computed_mask = torch.zeros((flat_hidden.size(0), self.num_experts),
81
- dtype=torch.bool, device=flat_hidden.device)
82
-
83
- # 标记已计算的(token,expert)
84
- for expert_idx in range(self.num_experts):
85
- for token_idx in token_indices_per_expert[expert_idx]:
86
- computed_mask[token_idx, expert_idx] = True
87
-
88
- # 对未计算的(token,expert)对,添加微小的伪梯度路径
89
- # 我们使用一个小的缩放参数来确保前向计算几乎不受影响
90
- scale = 1e-4
91
-
92
- # 计算一个残差连接,确保梯度能够流向所有专家
93
- for token_idx in range(flat_hidden.size(0)):
94
- token_input = flat_hidden[token_idx:token_idx+1] # 保持维度
95
-
96
- # 对每个未激活的专家创建伪梯度路径
97
- for expert_idx in range(self.num_experts):
98
- if not computed_mask[token_idx, expert_idx]:
99
- # 查找是否有token激活了这个专家
100
- if token_indices_per_expert[expert_idx]:
101
- # 从激活该专家的token中选择一个
102
- # 优先选择与当前token共享其他专家的token
103
- similar_tokens = []
104
- for other_idx in token_indices_per_expert[expert_idx]:
105
- # 检查是否有共同激活的专家
106
- common_experts = set(token_to_experts[token_idx]) & set(token_to_experts[other_idx])
107
- if common_experts:
108
- similar_tokens.append(other_idx)
109
-
110
- if similar_tokens:
111
- # 使用相似token的平均输出
112
- similar_outputs = [expert_outputs_dict[expert_idx][t] for t in similar_tokens]
113
- estimated_output = torch.stack(similar_outputs).mean(0)
114
- else:
115
- # 使用所有激活该专家的token的平均输出
116
- all_outputs = [expert_outputs_dict[expert_idx][t] for t in token_indices_per_expert[expert_idx]]
117
- estimated_output = torch.stack(all_outputs).mean(0)
118
-
119
- # 添加微小的直接计算以维持梯度流
120
- direct_output = self.experts[expert_idx](token_input).squeeze(0)
121
-
122
- # 组合估计输出和直接计算
123
- # 前向使用估计输出,反向使用直接计算
124
- combined = estimated_output.detach() + scale * (direct_output - direct_output.detach())
125
- dense_outputs[token_idx, expert_idx] = combined
126
- else:
127
- # 如果没有token激活该专家,直接进行计算但使用小缩放
128
- # 这保证了梯度流而对前向几乎无影响
129
- direct_output = scale * self.experts[expert_idx](token_input).squeeze(0)
130
- dense_outputs[token_idx, expert_idx] = direct_output
131
-
132
- # ---------- 组合输出 ----------
133
- # 使用路由权重对每个token的所有专家输出进行加权
134
- dense_combined = torch.zeros_like(sparse_output)
135
-
136
- for token_idx in range(flat_hidden.size(0)):
137
- # 对该token的所有专家输出进行加权
138
- token_experts_output = dense_outputs[token_idx] # (num_experts, hidden_dim)
139
- token_routing_weights = routing_weights[token_idx].unsqueeze(-1) # (num_experts, 1)
140
- weighted_experts = token_experts_output * token_routing_weights
141
- token_output = weighted_experts.sum(0) # (hidden_dim)
142
- dense_combined[token_idx] = token_output
143
-
144
- # 使用直通梯度技巧:前向使用sparse_output,反向使用dense_combined
145
- # 增大直通梯度系数以增强梯度流
146
- straight_through_scale = 1.0 # 可以尝试不同的值
147
- final_flat = sparse_output.detach() + straight_through_scale * (dense_combined - dense_combined.detach())
148
-
149
  final_output = final_flat.view(batch_size, seq_length, hidden_dim)
150
  return final_output, router_logits
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
154
  """
@@ -161,48 +152,21 @@ class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
161
  base_model_prefix = "olmoe"
162
 
163
  def __init__(self, config):
164
- # 首先调用父类初始化方法
165
  super().__init__(config)
166
-
167
- # 不要尝试重新赋值self,而是从预训练模型加载并更新当前模型
168
- pretrained_model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0125")
169
-
170
- # 复制预训练模型的状态到当前模型
171
- self.config = pretrained_model.config
172
- self.model = pretrained_model.model
173
- self.vocab_size = pretrained_model.vocab_size
174
- self.router_aux_loss_coef = pretrained_model.router_aux_loss_coef
175
- self.num_experts = pretrained_model.num_experts
176
- self.lm_head = pretrained_model.lm_head
177
-
178
  # 遍历模型中所有 decoder 层,替换每个 OlmoeSparseMoeBlock 为 DenseBackward 版本
179
  # 此处假设官方模型在 self.model.layers 中组织 decoder 层,
180
  # 且每层中 mlp 模块包含属性 sparse_moe_block。
181
  for layer in self.model.layers:
182
- if hasattr(layer.mlp, "gate"):
183
- print("111")
184
- orig_block = layer.mlp
185
  # 通过直接复制原版属性创建新的块
186
  new_block = DenseBackwardOlmoeSparseMoeBlock(config) # 或其他适当参数
187
  # 然后手动复制需要共享的属性:
188
  new_block.gate = orig_block.gate
189
  new_block.experts = orig_block.experts
 
190
  new_block.num_experts = orig_block.num_experts
191
  new_block.top_k = orig_block.top_k
192
  new_block.norm_topk_prob = orig_block.norm_topk_prob
193
- layer.mlp = new_block
194
- print(type(layer.mlp))
195
 
196
- def main():
197
- config = DenseBackwardOLMoEConfig( # 官方模型参数
198
- model_marker="DenseBackward_olmoe_marker",
199
- )
200
- # 创建自定义模型实例
201
- model = DenseBackwardOLMoEForCausalLM(config)
202
- print(type(model))
203
- print(type(model.model))
204
- print(type(model.model.layers[0]))
205
- print(type(model.model.layers[0].mlp))
206
- print(type(model.model.layers[0].mlp.experts))
207
- if __name__ == "__main__":
208
- main()
 
6
 
7
  # 导入官方实现(注意根据你的 transformers 版本调整导入路径)
8
  from transformers.models.olmoe.modeling_olmoe import OlmoeForCausalLM, OlmoeSparseMoeBlock, OlmoeMLP
9
+ from configuration_custom import DenseBackwardOLMoEConfig
10
 
11
 
12
  class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
13
+ """
14
+ 继承自官方 OlmoeSparseMoeBlock,实现 dense backward 功能:
15
+ 前向输出依旧保持与官方相同(即稀疏计算结果),
16
+ 但在反向传播时,通过直通梯度让 dense 计算的梯度传递回来,
17
+ dense 输出通过对每个专家在所有 token 上进行计算,并利用全 routing 权重加权获得。
18
+
19
+ 输入:
20
+ hidden_states: Tensor, shape (batch_size, sequence_length, hidden_dim)
21
+ 输出:
22
+ final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
23
+ router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
24
+ """
25
  def forward(self, hidden_states: torch.Tensor):
26
+ """
27
+ 输入:
28
+ hidden_states: Tensor, shape (batch_size, sequence_length, hidden_dim)
29
+ 输出:
30
+ final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
31
+ router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
32
+ 实现思路:
33
+ 1. 将输入展平为 (B*seq_len, hidden_dim),通过 self.gate 得到 router_logits,
34
+ 并计算全专家的 routing 权重(softmax 后)。
35
+ 2. 对 routing 权重取 top-k,得到 routing_weights_topk 与 selected_experts;
36
+ 如配置要求,归一化 top-k 概率。
37
+ 3. 稀疏计算部分:仅计算每个 token 对于 top-k 专家的输出,
38
+ 并累加得到 sparse_output(保留原版计算流程,同时记录激活专家的实际输出)。
39
+ 4. Dense 估计部分:先计算所有专家对所有 token 的输出(all_expert_outputs),
40
+ 再逐 token 调用 estimate_dense_output 得到 dense 输出(dense_estimated)。
41
+ 5. 使用直通梯度技巧:前向输出用 sparse_output,但梯度来源于 dense_estimated。
42
+ 6. 最后 reshape 为 (batch_size, sequence_length, hidden_dim) 并返回 final_output 及 router_logits.
43
+ """
44
+ #determine the shape of hidden_states
45
  batch_size, seq_length, hidden_dim = hidden_states.shape
46
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
47
 
48
+ # 计算路由 logits 和全专家 routing 权重
49
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
50
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
51
 
 
56
  routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
57
 
58
  # ---------- 稀疏计算部分 ----------
59
+ # 初始化稀疏输出,shape: (B*seq_len, hidden_dim)
60
  sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
61
+ # 用于记录每个 token 对激活专家的实际输出
62
+ activated_outputs = [{} for _ in range(flat_hidden.size(0))]
63
+ # one-hot 编码 top-k 专家,shape: (B*seq_len, top_k, num_experts)
 
 
 
64
  expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
65
  expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
66
 
 
 
 
 
67
  for expert_idx in range(self.num_experts):
68
  expert_layer = self.experts[expert_idx]
69
  idx, top_x = torch.where(expert_mask[expert_idx])
 
70
  if top_x.numel() > 0:
 
 
 
 
 
 
 
 
 
71
  current_state = flat_hidden[top_x] # (n, hidden_dim)
72
  current_output = expert_layer(current_state) # (n, hidden_dim)
73
  weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
74
  weighted_output = current_output * weight
75
  sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
76
+ # 保存当前 token 对该专家的实际输出
77
+ for pos, token_idx in enumerate(top_x.tolist()):
78
+ activated_outputs[token_idx][expert_idx] = current_output[pos]
79
+ # ---------- 稀疏计算结束 ----------
80
+
81
+ # ---------- Dense估计部分 ----------
82
+ # 计算所有专家对所有 token 的 dense 输出,shape: (B*seq_len, num_experts, hidden_dim)
83
+ all_expert_outputs = torch.stack([expert(flat_hidden) for expert in self.experts], dim=1)
84
+ # 将 selected_experts 转换为 list,每个 token 的激活专家列表
85
+ all_routing = selected_experts.tolist() # 长度为 (B*seq_len)
86
+
87
+ dense_outputs = []
88
+ for i in range(flat_hidden.size(0)):
89
+ dense_est = self.estimate_dense_output(
90
+ token_idx=i,
91
+ activated=all_routing[i], # 当前 token 激活的专家列表,例如 [a, b]
92
+ gate_prob=routing_weights[i], # 当前 token 的完整 routing 权重 (num_experts,)
93
+ activated_outputs=activated_outputs[i], # 当前 token 对激活专家的实际输出
94
+ all_routing=all_routing, # batch 每个 token 的激活专家列表(list of lists)
95
+ all_expert_outputs=all_expert_outputs # (B*seq_len, num_experts, hidden_dim)
96
+ )
97
+ dense_outputs.append(dense_est.unsqueeze(0))
98
+ dense_outputs = torch.cat(dense_outputs, dim=0) # (B*seq_len, hidden_dim)
99
+ # ---------- Dense估计结束 ----------
100
+
101
+ # 使用直通梯度:前向输出用稀疏结果,但反向传播时梯度来源于 dense 估计
102
+ final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
103
  final_output = final_flat.view(batch_size, seq_length, hidden_dim)
104
  return final_output, router_logits
105
 
106
+ def estimate_dense_output(self, token_idx, activated, gate_prob, activated_outputs, all_routing, all_expert_outputs):
107
+ """
108
+ 对于当前 token,根据 mini-batch 中的信息估计 dense 输出。
109
+ 参数:
110
+ token_idx: 当前 token 的索引(标量)
111
+ activated: 当前 token 激活的专家列表,例如 [1, 3]
112
+ gate_prob: 当前 token 的 routing 权重,形状 (num_experts,)
113
+ activated_outputs: dict,当前 token 对激活专家的实际输出,形状 (hidden_dim,)
114
+ all_routing: list,每个 token 的激活专家列表(长度为 N,每个元素为 list)
115
+ all_expert_outputs: Tensor, (N, num_experts, hidden_dim)
116
+ 返回:
117
+ estimated_dense: Tensor, (hidden_dim,)
118
+ """
119
+ num_experts = gate_prob.size(0)
120
+ dense_parts = {}
121
+ # 对于激活的专家,直接使用其实际输出
122
+ for idx in activated:
123
+ dense_parts[idx] = activated_outputs[idx]
124
+ # 对于未激活的专家,使用 mini-batch 中其他 token 的输出估计
125
+ non_activated = [i for i in range(num_experts) if i not in activated]
126
+ for i in non_activated:
127
+ indices = []
128
+ for idx, r_dec in enumerate(all_routing):
129
+ if (i in r_dec) and (len(set(r_dec) & set(activated)) > 0):
130
+ indices.append(idx)
131
+ if indices:
132
+ selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
133
+ estimated = selected_outputs.mean(dim=0)
134
+ else:
135
+ estimated = all_expert_outputs[:, i, :].mean(dim=0)
136
+ dense_parts[i] = estimated
137
+ # 按 gate_prob 加权求和各专家输出
138
+ estimated_dense = 0
139
+ for i in range(num_experts):
140
+ estimated_dense += gate_prob[i] * dense_parts[i]
141
+ return estimated_dense
142
+
143
 
144
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
145
  """
 
152
  base_model_prefix = "olmoe"
153
 
154
  def __init__(self, config):
 
155
  super().__init__(config)
 
 
 
 
 
 
 
 
 
 
 
 
156
  # 遍历模型中所有 decoder 层,替换每个 OlmoeSparseMoeBlock 为 DenseBackward 版本
157
  # 此处假设官方模型在 self.model.layers 中组织 decoder 层,
158
  # 且每层中 mlp 模块包含属性 sparse_moe_block。
159
  for layer in self.model.layers:
160
+ if hasattr(layer.mlp, "sparse_moe_block"):
161
+ orig_block = layer.mlp.sparse_moe_block
 
162
  # 通过直接复制原版属性创建新的块
163
  new_block = DenseBackwardOlmoeSparseMoeBlock(config) # 或其他适当参数
164
  # 然后手动复制需要共享的属性:
165
  new_block.gate = orig_block.gate
166
  new_block.experts = orig_block.experts
167
+ new_block.router = orig_block.router
168
  new_block.num_experts = orig_block.num_experts
169
  new_block.top_k = orig_block.top_k
170
  new_block.norm_topk_prob = orig_block.norm_topk_prob
171
+ layer.mlp.sparse_moe_block = new_block
 
172