Update modeling_densebackward_olmoe.py
modeling_densebackward_olmoe.py  +80 -78
CHANGED
@@ -10,42 +10,11 @@ from .configuration_densebackward_olmoe import DenseBackwardOLMoEConfig


class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
-    """
-    Inherits from the official OlmoeSparseMoeBlock and implements the dense-backward behavior:
-    the forward output stays identical to the official (sparse) computation,
-    but during backpropagation a straight-through estimator lets the gradient of a dense computation flow back;
-    the dense output is obtained by running every expert on all tokens and weighting with the full routing weights.
-
-    Input:
-        hidden_states: Tensor, shape (batch_size, sequence_length, hidden_dim)
-    Output:
-        final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
-        router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
-    """
    def forward(self, hidden_states: torch.Tensor):
-        """
-        Input:
-            hidden_states: Tensor, shape (batch_size, sequence_length, hidden_dim)
-        Output:
-            final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
-            router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
-        Approach:
-        1. Flatten the input to (B*seq_len, hidden_dim), pass it through self.gate to obtain router_logits,
-           and compute the routing weights over all experts (after softmax).
-        2. Take the top-k of the routing weights to get routing_weights_topk and selected_experts;
-           normalize the top-k probabilities if the config requires it.
-        3. Sparse part: compute each token's output only for its top-k experts and accumulate
-           the results into sparse_output (keeping the original computation flow while recording the activated experts' actual outputs).
-        4. Dense estimation part: first compute every expert's output on all tokens (all_expert_outputs),
-           then call estimate_dense_output per token to obtain the dense output (dense_estimated).
-        5. Straight-through trick: the forward output uses sparse_output, but the gradient comes from dense_estimated.
-        6. Finally reshape to (batch_size, sequence_length, hidden_dim) and return final_output and router_logits.
-        """
-        # determine the shape of hidden_states
        batch_size, seq_length, hidden_dim = hidden_states.shape
        flat_hidden = hidden_states.view(-1, hidden_dim)  # (B*seq_len, hidden_dim)

-        # compute the router logits
        router_logits = self.gate(flat_hidden)  # (B*seq_len, num_experts)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)  # (B*seq_len, num_experts)
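The removed docstring above (steps 1 and 2) describes the routing that this commit leaves untouched: flatten the tokens, apply the gate, softmax over all experts, then keep the top-k weights with optional renormalization (the top-k selection itself sits in unchanged lines the diff does not show). A minimal, self-contained sketch of that step, with made-up sizes and a stand-in linear layer rather than the block's real self.gate and config values:

import torch
import torch.nn.functional as F

# illustrative sizes only; the real values come from the model config
num_tokens, hidden_dim, num_experts, top_k = 6, 16, 8, 2

flat_hidden = torch.randn(num_tokens, hidden_dim)
gate = torch.nn.Linear(hidden_dim, num_experts, bias=False)  # stand-in for self.gate

router_logits = gate(flat_hidden)                                      # (num_tokens, num_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)   # full routing weights

# keep the top-k experts per token
routing_weights_topk, selected_experts = torch.topk(routing_weights, top_k, dim=-1)

# optional renormalization so each token's kept weights sum to 1
routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)

print(selected_experts.shape)        # torch.Size([6, 2])
print(routing_weights_topk.sum(-1))  # ~1.0 per token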
@@ -56,14 +25,20 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):

        routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)

        # ---------- sparse computation ----------
-        #
        sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
-
-
-
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)  # (B*seq_len, top_k, num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts, top_k, B*seq_len)

        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])
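The per-expert dispatch in the hunk above (unchanged by this commit) is driven by a one-hot mask over selected_experts: torch.where(expert_mask[expert_idx]) recovers, for each expert, which tokens routed to it (top_x) and through which top-k slot (idx), so routing_weights_topk[top_x, idx] later picks the matching weights. A toy illustration of that bookkeeping, with made-up routing decisions:

import torch
import torch.nn.functional as F

# 4 tokens, top_k = 2, 3 experts (made-up routing decisions)
selected_experts = torch.tensor([[0, 2],
                                 [1, 2],
                                 [0, 1],
                                 [2, 1]])
num_experts = 3

expert_mask = F.one_hot(selected_experts, num_classes=num_experts)  # (tokens, top_k, experts)
expert_mask = expert_mask.permute(2, 1, 0)                          # (experts, top_k, tokens)

for expert_idx in range(num_experts):
    # idx: which top-k slot picked this expert; top_x: which tokens picked it
    idx, top_x = torch.where(expert_mask[expert_idx])
    print(expert_idx, top_x.tolist(), idx.tolist())

# expert 0 is used by tokens [0, 2], both through slot 0;
# expert 1 by tokens [1, 2, 3]; expert 2 by tokens [3, 0, 1]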
@@ -73,71 +48,98 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):

            weight = routing_weights_topk[top_x, idx].unsqueeze(-1)  # (n, 1)
            weighted_output = current_output * weight
            sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
-
            for pos, token_idx in enumerate(top_x.tolist()):
-                ...
        # ---------- end of sparse computation ----------

        # ---------- dense estimation ----------
-        #
-        all_expert_outputs = torch.stack([expert(flat_hidden) for expert in self.experts], dim=1)
-        # convert selected_experts to a list: the activated-expert list of each token
        all_routing = selected_experts.tolist()  # length (B*seq_len)

        dense_outputs = []
-        for ...
-            ...
            )
            dense_outputs.append(dense_est.unsqueeze(0))
        dense_outputs = torch.cat(dense_outputs, dim=0)  # (B*seq_len, hidden_dim)
        # ---------- end of dense estimation ----------

-        #
        final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
        final_output = final_flat.view(batch_size, seq_length, hidden_dim)
        return final_output, router_logits

-    def estimate_dense_output(...):
        """
-        Args:
-            token_idx: index of the current token (a scalar)
-            activated: list of the experts activated by the current token, e.g. [1, 3]
-            gate_prob: routing weights of the current token, shape (num_experts,)
-            activated_outputs: dict with the current token's actual output for each activated expert, each of shape (hidden_dim,)
-            all_routing: list of the activated-expert lists of all tokens (length N, each element a list)
-            all_expert_outputs: Tensor, (N, num_experts, hidden_dim)
-        Returns:
-            estimated_dense: Tensor, (hidden_dim,)
        """
        num_experts = gate_prob.size(0)
        dense_parts = {}
-        ...
        non_activated = [i for i in range(num_experts) if i not in activated]
-        for ...
-            ...
            else:
-                ...
        estimated_dense = 0
-        for ...
-            ...
        return estimated_dense
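Both before and after this change, the forward pass combines the two paths with the same straight-through expression, sparse_output.detach() + (dense_outputs - dense_outputs.detach()): the forward value is exactly the sparse output, while the backward pass sees only the dense estimate. A minimal sketch of that identity on scalars (illustrative values only):

import torch

sparse = torch.tensor(2.0, requires_grad=True)
dense = torch.tensor(5.0, requires_grad=True)

# value of the sparse path, gradient of the dense path
out = sparse.detach() + (dense - dense.detach())
print(out.item())    # 2.0 -> forward value equals the sparse result

out.backward()
print(sparse.grad)   # None -> no gradient reaches the sparse path
print(dense.grad)    # tensor(1.) -> the gradient flows through the dense path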
@@ -10,42 +10,11 @@ from .configuration_densebackward_olmoe import DenseBackwardOLMoEConfig


class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
    def forward(self, hidden_states: torch.Tensor):
        batch_size, seq_length, hidden_dim = hidden_states.shape
        flat_hidden = hidden_states.view(-1, hidden_dim)  # (B*seq_len, hidden_dim)

+        # compute the router logits and the routing weights
        router_logits = self.gate(flat_hidden)  # (B*seq_len, num_experts)
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)  # (B*seq_len, num_experts)

@@ -56,14 +25,20 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):

        routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)

        # ---------- sparse computation ----------
+        # initialize the sparse output
        sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
+
+        # data structures that record all activation information
+        num_tokens = flat_hidden.size(0)
+        all_activated_outputs = {}    # {expert_idx: {token_idx: output_tensor}}
+        all_routing_indices = {}      # {expert_idx: [token_indices]}
+        token_activated_experts = {}  # {token_idx: [activated_expert_indices]}
+
+        # one-hot encode the top-k experts
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)  # (B*seq_len, top_k, num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts, top_k, B*seq_len)

+        # sparse computation, recording activation information along the way
        for expert_idx in range(self.num_experts):
            expert_layer = self.experts[expert_idx]
            idx, top_x = torch.where(expert_mask[expert_idx])

@@ -73,71 +48,98 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):

            weight = routing_weights_topk[top_x, idx].unsqueeze(-1)  # (n, 1)
            weighted_output = current_output * weight
            sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
+
+            # record the tokens routed to this expert and the corresponding outputs
+            all_activated_outputs[expert_idx] = {}
+            all_routing_indices[expert_idx] = top_x.tolist()
+
            for pos, token_idx in enumerate(top_x.tolist()):
+                # record this expert's output for this token
+                all_activated_outputs[expert_idx][token_idx] = current_output[pos]
+
+                # record the experts activated by this token
+                if token_idx not in token_activated_experts:
+                    token_activated_experts[token_idx] = []
+                token_activated_experts[token_idx].append(expert_idx)
        # ---------- end of sparse computation ----------

        # ---------- dense estimation ----------
+        # convert selected_experts to list form, matching the routing weights
        all_routing = selected_experts.tolist()  # length (B*seq_len)

+        # estimate the dense output from the activation information collected above
        dense_outputs = []
+        for token_idx in range(num_tokens):
+            # the list of experts activated by the current token
+            activated = all_routing[token_idx] if token_idx in token_activated_experts else []
+
+            # estimate the dense output (using only expert outputs that were already computed)
+            dense_est = self.estimate_dense_output_efficient(
+                token_idx=token_idx,
+                activated=activated,
+                gate_prob=routing_weights[token_idx],
+                all_activated_outputs=all_activated_outputs,
+                all_routing_indices=all_routing_indices,
+                token_activated_experts=token_activated_experts
            )
            dense_outputs.append(dense_est.unsqueeze(0))
+
        dense_outputs = torch.cat(dense_outputs, dim=0)  # (B*seq_len, hidden_dim)
        # ---------- end of dense estimation ----------

+        # straight-through gradient trick
        final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
        final_output = final_flat.view(batch_size, seq_length, hidden_dim)
        return final_output, router_logits

+    def estimate_dense_output_efficient(self, token_idx, activated, gate_prob,
+                                        all_activated_outputs, all_routing_indices, token_activated_experts):
        """
+        Optimized dense-output estimation that uses only the expert outputs already computed in the sparse pass.
        """
        num_experts = gate_prob.size(0)
        dense_parts = {}
+
+        # for activated experts, use their recorded outputs directly
+        for expert_idx in activated:
+            if expert_idx in all_activated_outputs and token_idx in all_activated_outputs[expert_idx]:
+                dense_parts[expert_idx] = all_activated_outputs[expert_idx][token_idx]
+
+        # for non-activated experts, estimate from the outputs they produced for other tokens
        non_activated = [i for i in range(num_experts) if i not in activated]
+        for expert_idx in non_activated:
+            # if no token activated this expert, there is nothing to average over
+            if expert_idx not in all_routing_indices or not all_routing_indices[expert_idx]:
+                # fall back to a zero vector (or a plain 0) as the estimate
+                dense_parts[expert_idx] = torch.zeros_like(next(iter(dense_parts.values()))) if dense_parts else 0
+                continue
+
+            # find tokens that activated this expert and also share some activated experts with the current token
+            candidate_tokens = []
+            for other_token in all_routing_indices[expert_idx]:
+                # check whether other_token shares any activated experts with the current token
+                if other_token in token_activated_experts:
+                    common_experts = set(activated) & set(token_activated_experts[other_token])
+                    if common_experts:
+                        candidate_tokens.append(other_token)
+
+            # if candidate tokens were found, use the mean of their outputs
+            if candidate_tokens:
+                expert_outputs = [all_activated_outputs[expert_idx][t] for t in candidate_tokens]
+                estimated = torch.stack(expert_outputs).mean(dim=0)
            else:
+                # no suitable candidates: use all tokens that activated this expert
+                expert_outputs = [all_activated_outputs[expert_idx][t] for t in all_routing_indices[expert_idx]]
+                estimated = torch.stack(expert_outputs).mean(dim=0)
+
+            dense_parts[expert_idx] = estimated
+
+        # weighted sum with the routing weights
        estimated_dense = 0
+        for expert_idx in range(num_experts):
+            if expert_idx in dense_parts:
+                estimated_dense += gate_prob[expert_idx] * dense_parts[expert_idx]
+
        return estimated_dense
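The heart of the new estimate_dense_output_efficient is the rule for experts a token did not activate: reuse that expert's outputs on other tokens, preferring tokens that share at least one activated expert with the current one, and average them. A standalone toy sketch of that selection rule, with fabricated bookkeeping in the same shape the new forward() builds (it does not call the method itself):

import torch

torch.manual_seed(0)
hidden_dim = 4

# fabricated bookkeeping: expert -> {token -> output computed in the sparse pass}
all_activated_outputs = {
    0: {0: torch.randn(hidden_dim), 2: torch.randn(hidden_dim)},
    1: {1: torch.randn(hidden_dim), 2: torch.randn(hidden_dim)},
    2: {0: torch.randn(hidden_dim), 1: torch.randn(hidden_dim)},
}
all_routing_indices = {e: sorted(d) for e, d in all_activated_outputs.items()}
token_activated_experts = {0: [0, 2], 1: [1, 2], 2: [0, 1]}

# estimate expert 1's output for token 0, which never ran through expert 1
token_idx, expert_idx = 0, 1
activated = token_activated_experts[token_idx]

candidate_tokens = [t for t in all_routing_indices[expert_idx]
                    if set(activated) & set(token_activated_experts[t])]
# token 1 shares expert 2 with token 0, token 2 shares expert 0 -> both qualify
estimated = torch.stack([all_activated_outputs[expert_idx][t]
                         for t in candidate_tokens]).mean(dim=0)

print(candidate_tokens)  # [1, 2]
print(estimated.shape)   # torch.Size([4])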