autoprogrammer commited on
Commit
9d56eed
·
verified ·
1 Parent(s): 73d4cc5

Update modeling_densebackward_olmoe0125.py

Browse files
Files changed (1) hide show
  1. modeling_densebackward_olmoe0125.py +103 -96
modeling_densebackward_olmoe0125.py CHANGED
@@ -14,7 +14,7 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
14
  batch_size, seq_length, hidden_dim = hidden_states.shape
15
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
16
 
17
- # 计算路由 logits 和 routing 权重
18
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
19
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
20
 
@@ -28,120 +28,127 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
28
  # 初始化稀疏输出
29
  sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
30
 
31
- # 存储所有激活信息的数据结构
32
- num_tokens = flat_hidden.size(0)
33
- all_activated_outputs = {} # {expert_idx: {token_idx: output_tensor}}
34
- all_routing_indices = {} # {expert_idx: [token_indices]}
35
- token_activated_experts = {} # {token_idx: [activated_expert_indices]}
36
 
37
  # one-hot 编码 top-k 专家
38
  expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
39
  expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
40
 
41
- # 稀疏计算,同时记录激活情况
 
 
 
42
  for expert_idx in range(self.num_experts):
43
  expert_layer = self.experts[expert_idx]
44
  idx, top_x = torch.where(expert_mask[expert_idx])
 
45
  if top_x.numel() > 0:
 
 
 
 
 
 
 
 
 
46
  current_state = flat_hidden[top_x] # (n, hidden_dim)
47
  current_output = expert_layer(current_state) # (n, hidden_dim)
48
  weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
49
  weighted_output = current_output * weight
50
  sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
51
 
52
- # 记录该专家激活的token和对应输出
53
- all_activated_outputs[expert_idx] = {}
54
- all_routing_indices[expert_idx] = top_x.tolist()
55
-
56
- for pos, token_idx in enumerate(top_x.tolist()):
57
- # 记录该专家对该token的输出
58
- all_activated_outputs[expert_idx][token_idx] = current_output[pos]
59
-
60
- # 记录该token激活的专家
61
- if token_idx not in token_activated_experts:
62
- token_activated_experts[token_idx] = []
63
- token_activated_experts[token_idx].append(expert_idx)
64
- # ---------- 稀疏计算结束 ----------
65
-
66
- # ---------- Dense估计部分 ----------
67
- # 将activated_experts 转换为list格式,与路由权重匹配
68
- all_routing = selected_experts.tolist() # 长度为 (B*seq_len)
69
-
70
- # 使用已激活信息估计dense输出
71
- dense_outputs = []
72
- for token_idx in range(num_tokens):
73
- # 获取当前token的激活专家列表
74
- activated = all_routing[token_idx] if token_idx in token_activated_experts else []
75
-
76
- # 估计dense输出(只使用已经计算过的专家输出)
77
- dense_est = self.estimate_dense_output_efficient(
78
- token_idx=token_idx,
79
- activated=activated,
80
- gate_prob=routing_weights[token_idx],
81
- all_activated_outputs=all_activated_outputs,
82
- all_routing_indices=all_routing_indices,
83
- token_activated_experts=token_activated_experts
84
- )
85
- dense_outputs.append(dense_est.unsqueeze(0))
86
 
87
- dense_outputs = torch.cat(dense_outputs, dim=0) # (B*seq_len, hidden_dim)
88
- # ---------- Dense估计结束 ----------
89
-
90
- # 使用直通梯度技巧
91
- final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
92
  final_output = final_flat.view(batch_size, seq_length, hidden_dim)
93
  return final_output, router_logits
94
 
95
- def estimate_dense_output_efficient(self, token_idx, activated, gate_prob,
96
- all_activated_outputs, all_routing_indices, token_activated_experts):
97
- """
98
- 优化版本的dense输出估计,只使用已计算的专家输出
99
- """
100
- num_experts = gate_prob.size(0)
101
- dense_parts = {}
102
-
103
- # 对于激活的专家,直接使用其输出
104
- for expert_idx in activated:
105
- if expert_idx in all_activated_outputs and token_idx in all_activated_outputs[expert_idx]:
106
- dense_parts[expert_idx] = all_activated_outputs[expert_idx][token_idx]
107
-
108
- # 对于未激活的专家,使用其他token的激活输出估计
109
- non_activated = [i for i in range(num_experts) if i not in activated]
110
- for expert_idx in non_activated:
111
- # 如果该专家没有被任何token激活,跳过
112
- if expert_idx not in all_routing_indices or not all_routing_indices[expert_idx]:
113
- # 使用零向量或平均值作为估计
114
- dense_parts[expert_idx] = torch.zeros_like(next(iter(dense_parts.values()))) if dense_parts else 0
115
- continue
116
-
117
- # 找出激活了该专家的token,并且这些token也激活了当前token激活的某些专家
118
- candidate_tokens = []
119
- for other_token in all_routing_indices[expert_idx]:
120
- # 检查other_token是否与当前token共享某些激活专家
121
- if other_token in token_activated_experts:
122
- common_experts = set(activated) & set(token_activated_experts[other_token])
123
- if common_experts:
124
- candidate_tokens.append(other_token)
125
-
126
- # 如果找到了候选token,使用它们的输出平均值
127
- if candidate_tokens:
128
- expert_outputs = [all_activated_outputs[expert_idx][t] for t in candidate_tokens]
129
- estimated = torch.stack(expert_outputs).mean(dim=0)
130
- else:
131
- # 找不到合适的候选,使用所有激活了该专家的token
132
- expert_outputs = [all_activated_outputs[expert_idx][t] for t in all_routing_indices[expert_idx]]
133
- estimated = torch.stack(expert_outputs).mean(dim=0)
134
-
135
- dense_parts[expert_idx] = estimated
136
-
137
- # 按路由权重加权求和
138
- estimated_dense = 0
139
- for expert_idx in range(num_experts):
140
- if expert_idx in dense_parts:
141
- estimated_dense += gate_prob[expert_idx] * dense_parts[expert_idx]
142
-
143
- return estimated_dense
144
-
145
 
146
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
147
  """
 
14
  batch_size, seq_length, hidden_dim = hidden_states.shape
15
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
16
 
17
+ # 计算路由 logits 和路由权重
18
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
19
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
20
 
 
28
  # 初始化稀疏输出
29
  sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
30
 
31
+ # 记录每个专家被激活的token
32
+ token_indices_per_expert = [[] for _ in range(self.num_experts)]
33
+ expert_outputs_dict = {i: {} for i in range(self.num_experts)}
 
 
34
 
35
  # one-hot 编码 top-k 专家
36
  expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
37
  expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
38
 
39
+ # 保存每个token激活的专家列表
40
+ token_to_experts = [[] for _ in range(flat_hidden.size(0))]
41
+
42
+ # 稀疏计算
43
  for expert_idx in range(self.num_experts):
44
  expert_layer = self.experts[expert_idx]
45
  idx, top_x = torch.where(expert_mask[expert_idx])
46
+
47
  if top_x.numel() > 0:
48
+ # 记录该专家被哪些token激活
49
+ top_x_list = top_x.tolist()
50
+ token_indices_per_expert[expert_idx].extend(top_x_list)
51
+
52
+ # 记录每个token激活了哪些专家
53
+ for token_idx in top_x_list:
54
+ token_to_experts[token_idx].append(expert_idx)
55
+
56
+ # 标准MoE前向计算
57
  current_state = flat_hidden[top_x] # (n, hidden_dim)
58
  current_output = expert_layer(current_state) # (n, hidden_dim)
59
  weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
60
  weighted_output = current_output * weight
61
  sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
62
 
63
+ # 保存未加权的专家输出,用于后续估计
64
+ for i, token_idx in enumerate(top_x_list):
65
+ expert_outputs_dict[expert_idx][token_idx] = current_output[i]
66
+
67
+ # ---------- 初始化密集估计输出 ----------
68
+ # 创建一个每个token对所有专家输出的张量
69
+ dense_outputs = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
70
+ dtype=flat_hidden.dtype, device=flat_hidden.device)
71
+
72
+ # 首先将���计算的专家输出填入
73
+ for expert_idx in range(self.num_experts):
74
+ for token_idx, output in expert_outputs_dict[expert_idx].items():
75
+ dense_outputs[token_idx, expert_idx] = output
76
+
77
+ # ---------- 添加伪梯度路径 ----------
78
+ # 对于每个未激活的专家,创建一个小的梯度路径
79
+ # 首先创建一个全零的mask标记哪些(token,expert)对已被计算
80
+ computed_mask = torch.zeros((flat_hidden.size(0), self.num_experts),
81
+ dtype=torch.bool, device=flat_hidden.device)
82
+
83
+ # 标记已计算的(token,expert)
84
+ for expert_idx in range(self.num_experts):
85
+ for token_idx in token_indices_per_expert[expert_idx]:
86
+ computed_mask[token_idx, expert_idx] = True
87
+
88
+ # 对未计算的(token,expert)对,添加微小的伪梯度路径
89
+ # 我们使用一个小的缩放参数来确保前向计算几乎不受影响
90
+ scale = 1e-4
91
+
92
+ # 计算一个残差连接,确保梯度能够流向所有专家
93
+ for token_idx in range(flat_hidden.size(0)):
94
+ token_input = flat_hidden[token_idx:token_idx+1] # 保持维度
 
 
95
 
96
+ # 对每个未激活的专家创建伪梯度路径
97
+ for expert_idx in range(self.num_experts):
98
+ if not computed_mask[token_idx, expert_idx]:
99
+ # 查找是否有token激活了这个专家
100
+ if token_indices_per_expert[expert_idx]:
101
+ # 从激活该专家的token中选择一个
102
+ # 优先选择与当前token共享其他专家的token
103
+ similar_tokens = []
104
+ for other_idx in token_indices_per_expert[expert_idx]:
105
+ # 检查是否有共同激活的专家
106
+ common_experts = set(token_to_experts[token_idx]) & set(token_to_experts[other_idx])
107
+ if common_experts:
108
+ similar_tokens.append(other_idx)
109
+
110
+ if similar_tokens:
111
+ # 使用相似token的平均输出
112
+ similar_outputs = [expert_outputs_dict[expert_idx][t] for t in similar_tokens]
113
+ estimated_output = torch.stack(similar_outputs).mean(0)
114
+ else:
115
+ # 使用所有激活该专家的token的平均输出
116
+ all_outputs = [expert_outputs_dict[expert_idx][t] for t in token_indices_per_expert[expert_idx]]
117
+ estimated_output = torch.stack(all_outputs).mean(0)
118
+
119
+ # 添加微小的直接计算以维持梯度流
120
+ direct_output = self.experts[expert_idx](token_input).squeeze(0)
121
+
122
+ # 组合估计输出和直接计算
123
+ # 前向使用估计输出,反向使用直接计算
124
+ combined = estimated_output.detach() + scale * (direct_output - direct_output.detach())
125
+ dense_outputs[token_idx, expert_idx] = combined
126
+ else:
127
+ # 如果没有token激活该专家,直接进行计算但使用小缩放
128
+ # 这保证了梯度流而对前向几乎无影响
129
+ direct_output = scale * self.experts[expert_idx](token_input).squeeze(0)
130
+ dense_outputs[token_idx, expert_idx] = direct_output
131
+
132
+ # ---------- 组合输出 ----------
133
+ # 使用路由权重对每个token的所有专家输出进行加权
134
+ dense_combined = torch.zeros_like(sparse_output)
135
+
136
+ for token_idx in range(flat_hidden.size(0)):
137
+ # 对该token的所有专家输出进行加权
138
+ token_experts_output = dense_outputs[token_idx] # (num_experts, hidden_dim)
139
+ token_routing_weights = routing_weights[token_idx].unsqueeze(-1) # (num_experts, 1)
140
+ weighted_experts = token_experts_output * token_routing_weights
141
+ token_output = weighted_experts.sum(0) # (hidden_dim)
142
+ dense_combined[token_idx] = token_output
143
+
144
+ # 使用直通梯度技巧:前向使用sparse_output,反向使用dense_combined
145
+ # 增大直通梯度系数以增强梯度流
146
+ straight_through_scale = 1.0 # 可以尝试不同的值
147
+ final_flat = sparse_output.detach() + straight_through_scale * (dense_combined - dense_combined.detach())
148
+
149
  final_output = final_flat.view(batch_size, seq_length, hidden_dim)
150
  return final_output, router_logits
151
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
154
  """