Update modeling_densebackward_olmoe0125.py
Browse files- modeling_densebackward_olmoe0125.py +103 -96
modeling_densebackward_olmoe0125.py
CHANGED
@@ -14,7 +14,7 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
14 |
batch_size, seq_length, hidden_dim = hidden_states.shape
|
15 |
flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
|
16 |
|
17 |
-
# 计算路由 logits
|
18 |
router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
|
19 |
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
|
20 |
|
@@ -28,120 +28,127 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
28 |
# 初始化稀疏输出
|
29 |
sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
|
30 |
|
31 |
-
#
|
32 |
-
|
33 |
-
|
34 |
-
all_routing_indices = {} # {expert_idx: [token_indices]}
|
35 |
-
token_activated_experts = {} # {token_idx: [activated_expert_indices]}
|
36 |
|
37 |
# one-hot 编码 top-k 专家
|
38 |
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
|
39 |
expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
|
40 |
|
41 |
-
#
|
|
|
|
|
|
|
42 |
for expert_idx in range(self.num_experts):
|
43 |
expert_layer = self.experts[expert_idx]
|
44 |
idx, top_x = torch.where(expert_mask[expert_idx])
|
|
|
45 |
if top_x.numel() > 0:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
46 |
current_state = flat_hidden[top_x] # (n, hidden_dim)
|
47 |
current_output = expert_layer(current_state) # (n, hidden_dim)
|
48 |
weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
|
49 |
weighted_output = current_output * weight
|
50 |
sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
|
51 |
|
52 |
-
#
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
# ----------
|
67 |
-
#
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
)
|
85 |
-
dense_outputs.append(dense_est.unsqueeze(0))
|
86 |
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
|
91 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
final_output = final_flat.view(batch_size, seq_length, hidden_dim)
|
93 |
return final_output, router_logits
|
94 |
|
95 |
-
def estimate_dense_output_efficient(self, token_idx, activated, gate_prob,
|
96 |
-
all_activated_outputs, all_routing_indices, token_activated_experts):
|
97 |
-
"""
|
98 |
-
优化版本的dense输出估计,只使用已计算的专家输出
|
99 |
-
"""
|
100 |
-
num_experts = gate_prob.size(0)
|
101 |
-
dense_parts = {}
|
102 |
-
|
103 |
-
# 对于激活的专家,直接使用其输出
|
104 |
-
for expert_idx in activated:
|
105 |
-
if expert_idx in all_activated_outputs and token_idx in all_activated_outputs[expert_idx]:
|
106 |
-
dense_parts[expert_idx] = all_activated_outputs[expert_idx][token_idx]
|
107 |
-
|
108 |
-
# 对于未激活的专家,使用其他token的激活输出估计
|
109 |
-
non_activated = [i for i in range(num_experts) if i not in activated]
|
110 |
-
for expert_idx in non_activated:
|
111 |
-
# 如果该专家没有被任何token激活,跳过
|
112 |
-
if expert_idx not in all_routing_indices or not all_routing_indices[expert_idx]:
|
113 |
-
# 使用零向量或平均值作为估计
|
114 |
-
dense_parts[expert_idx] = torch.zeros_like(next(iter(dense_parts.values()))) if dense_parts else 0
|
115 |
-
continue
|
116 |
-
|
117 |
-
# 找出激活了该专家的token,并且这些token也激活了当前token激活的某些专家
|
118 |
-
candidate_tokens = []
|
119 |
-
for other_token in all_routing_indices[expert_idx]:
|
120 |
-
# 检查other_token是否与当前token共享某些激活专家
|
121 |
-
if other_token in token_activated_experts:
|
122 |
-
common_experts = set(activated) & set(token_activated_experts[other_token])
|
123 |
-
if common_experts:
|
124 |
-
candidate_tokens.append(other_token)
|
125 |
-
|
126 |
-
# 如果找到了候选token,使用它们的输出平均值
|
127 |
-
if candidate_tokens:
|
128 |
-
expert_outputs = [all_activated_outputs[expert_idx][t] for t in candidate_tokens]
|
129 |
-
estimated = torch.stack(expert_outputs).mean(dim=0)
|
130 |
-
else:
|
131 |
-
# 找不到合适的候选,使用所有激活了该专家的token
|
132 |
-
expert_outputs = [all_activated_outputs[expert_idx][t] for t in all_routing_indices[expert_idx]]
|
133 |
-
estimated = torch.stack(expert_outputs).mean(dim=0)
|
134 |
-
|
135 |
-
dense_parts[expert_idx] = estimated
|
136 |
-
|
137 |
-
# 按路由权重加权求和
|
138 |
-
estimated_dense = 0
|
139 |
-
for expert_idx in range(num_experts):
|
140 |
-
if expert_idx in dense_parts:
|
141 |
-
estimated_dense += gate_prob[expert_idx] * dense_parts[expert_idx]
|
142 |
-
|
143 |
-
return estimated_dense
|
144 |
-
|
145 |
|
146 |
class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
|
147 |
"""
|
|
|
14 |
batch_size, seq_length, hidden_dim = hidden_states.shape
|
15 |
flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
|
16 |
|
17 |
+
# 计算路由 logits 和路由权重
|
18 |
router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
|
19 |
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
|
20 |
|
|
|
28 |
# 初始化稀疏输出
|
29 |
sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
|
30 |
|
31 |
+
# 记录每个专家被激活的token
|
32 |
+
token_indices_per_expert = [[] for _ in range(self.num_experts)]
|
33 |
+
expert_outputs_dict = {i: {} for i in range(self.num_experts)}
|
|
|
|
|
34 |
|
35 |
# one-hot 编码 top-k 专家
|
36 |
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
|
37 |
expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
|
38 |
|
39 |
+
# 保存每个token激活的专家列表
|
40 |
+
token_to_experts = [[] for _ in range(flat_hidden.size(0))]
|
41 |
+
|
42 |
+
# 稀疏计算
|
43 |
for expert_idx in range(self.num_experts):
|
44 |
expert_layer = self.experts[expert_idx]
|
45 |
idx, top_x = torch.where(expert_mask[expert_idx])
|
46 |
+
|
47 |
if top_x.numel() > 0:
|
48 |
+
# 记录该专家被哪些token激活
|
49 |
+
top_x_list = top_x.tolist()
|
50 |
+
token_indices_per_expert[expert_idx].extend(top_x_list)
|
51 |
+
|
52 |
+
# 记录每个token激活了哪些专家
|
53 |
+
for token_idx in top_x_list:
|
54 |
+
token_to_experts[token_idx].append(expert_idx)
|
55 |
+
|
56 |
+
# 标准MoE前向计算
|
57 |
current_state = flat_hidden[top_x] # (n, hidden_dim)
|
58 |
current_output = expert_layer(current_state) # (n, hidden_dim)
|
59 |
weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
|
60 |
weighted_output = current_output * weight
|
61 |
sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
|
62 |
|
63 |
+
# 保存未加权的专家输出,用于后续估计
|
64 |
+
for i, token_idx in enumerate(top_x_list):
|
65 |
+
expert_outputs_dict[expert_idx][token_idx] = current_output[i]
|
66 |
+
|
67 |
+
# ---------- 初始化密集估计输出 ----------
|
68 |
+
# 创建一个每个token对所有专家输出的张量
|
69 |
+
dense_outputs = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
|
70 |
+
dtype=flat_hidden.dtype, device=flat_hidden.device)
|
71 |
+
|
72 |
+
# 首先将已计算的专家输出填入
|
73 |
+
for expert_idx in range(self.num_experts):
|
74 |
+
for token_idx, output in expert_outputs_dict[expert_idx].items():
|
75 |
+
dense_outputs[token_idx, expert_idx] = output
|
76 |
+
|
77 |
+
# ---------- 添加伪梯度路径 ----------
|
78 |
+
# 对于每个未激活的专家,创建一个小的梯度路径
|
79 |
+
# 首先创建一个全零的mask标记哪些(token,expert)对已被计算
|
80 |
+
computed_mask = torch.zeros((flat_hidden.size(0), self.num_experts),
|
81 |
+
dtype=torch.bool, device=flat_hidden.device)
|
82 |
+
|
83 |
+
# 标记已计算的(token,expert)对
|
84 |
+
for expert_idx in range(self.num_experts):
|
85 |
+
for token_idx in token_indices_per_expert[expert_idx]:
|
86 |
+
computed_mask[token_idx, expert_idx] = True
|
87 |
+
|
88 |
+
# 对未计算的(token,expert)对,添加微小的伪梯度路径
|
89 |
+
# 我们使用一个小的缩放参数来确保前向计算几乎不受影响
|
90 |
+
scale = 1e-4
|
91 |
+
|
92 |
+
# 计算一个残差连接,确保梯度能够流向所有专家
|
93 |
+
for token_idx in range(flat_hidden.size(0)):
|
94 |
+
token_input = flat_hidden[token_idx:token_idx+1] # 保持维度
|
|
|
|
|
95 |
|
96 |
+
# 对每个未激活的专家创建伪梯度路径
|
97 |
+
for expert_idx in range(self.num_experts):
|
98 |
+
if not computed_mask[token_idx, expert_idx]:
|
99 |
+
# 查找是否有token激活了这个专家
|
100 |
+
if token_indices_per_expert[expert_idx]:
|
101 |
+
# 从激活该专家的token中选择一个
|
102 |
+
# 优先选择与当前token共享其他专家的token
|
103 |
+
similar_tokens = []
|
104 |
+
for other_idx in token_indices_per_expert[expert_idx]:
|
105 |
+
# 检查是否有共同激活的专家
|
106 |
+
common_experts = set(token_to_experts[token_idx]) & set(token_to_experts[other_idx])
|
107 |
+
if common_experts:
|
108 |
+
similar_tokens.append(other_idx)
|
109 |
+
|
110 |
+
if similar_tokens:
|
111 |
+
# 使用相似token的平均输出
|
112 |
+
similar_outputs = [expert_outputs_dict[expert_idx][t] for t in similar_tokens]
|
113 |
+
estimated_output = torch.stack(similar_outputs).mean(0)
|
114 |
+
else:
|
115 |
+
# 使用所有激活该专家的token的平均输出
|
116 |
+
all_outputs = [expert_outputs_dict[expert_idx][t] for t in token_indices_per_expert[expert_idx]]
|
117 |
+
estimated_output = torch.stack(all_outputs).mean(0)
|
118 |
+
|
119 |
+
# 添加微小的直接计算以维持梯度流
|
120 |
+
direct_output = self.experts[expert_idx](token_input).squeeze(0)
|
121 |
+
|
122 |
+
# 组合估计输出和直接计算
|
123 |
+
# 前向使用估计输出,反向使用直接计算
|
124 |
+
combined = estimated_output.detach() + scale * (direct_output - direct_output.detach())
|
125 |
+
dense_outputs[token_idx, expert_idx] = combined
|
126 |
+
else:
|
127 |
+
# 如果没有token激活该专家,直接进行计算但使用小缩放
|
128 |
+
# 这保证了梯度流而对前向几乎无影响
|
129 |
+
direct_output = scale * self.experts[expert_idx](token_input).squeeze(0)
|
130 |
+
dense_outputs[token_idx, expert_idx] = direct_output
|
131 |
+
|
132 |
+
# ---------- 组合输出 ----------
|
133 |
+
# 使用路由权重对每个token的所有专家输出进行加权
|
134 |
+
dense_combined = torch.zeros_like(sparse_output)
|
135 |
+
|
136 |
+
for token_idx in range(flat_hidden.size(0)):
|
137 |
+
# 对该token的所有专家输出进行加权
|
138 |
+
token_experts_output = dense_outputs[token_idx] # (num_experts, hidden_dim)
|
139 |
+
token_routing_weights = routing_weights[token_idx].unsqueeze(-1) # (num_experts, 1)
|
140 |
+
weighted_experts = token_experts_output * token_routing_weights
|
141 |
+
token_output = weighted_experts.sum(0) # (hidden_dim)
|
142 |
+
dense_combined[token_idx] = token_output
|
143 |
+
|
144 |
+
# 使用直通梯度技巧:前向使用sparse_output,反向使用dense_combined
|
145 |
+
# 增大直通梯度系数以增强梯度流
|
146 |
+
straight_through_scale = 1.0 # 可以尝试不同的值
|
147 |
+
final_flat = sparse_output.detach() + straight_through_scale * (dense_combined - dense_combined.detach())
|
148 |
+
|
149 |
final_output = final_flat.view(batch_size, seq_length, hidden_dim)
|
150 |
return final_output, router_logits
|
151 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
152 |
|
153 |
class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
|
154 |
"""
|