Update modeling_densebackward_olmoe0125.py
modeling_densebackward_olmoe0125.py
CHANGED
@@ -10,6 +10,7 @@ from .configuration_densebackward_olmoe0125 import DenseBackwardOLMoEConfig
 
 
 class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
+
     """
     Inherits from the official OlmoeSparseMoeBlock and implements dense-backward behavior:
     the forward output is still identical to the official (sparse) computation,
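The docstring pins down the core contract: the forward value is the ordinary sparse MoE result, while the gradient flows through a dense estimate. A minimal, self-contained sketch (toy tensors, not from this repository) of the straight-through combination used at the end of forward():

```python
import torch

# Forward value equals sparse_output; gradients flow only through dense_outputs.
sparse_output = torch.randn(4, 8, requires_grad=True)
dense_outputs = torch.randn(4, 8, requires_grad=True)

final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())

assert torch.equal(final_flat, sparse_output.detach())  # value is the sparse result
final_flat.sum().backward()
assert sparse_output.grad is None                # the sparse path carries no gradient
assert torch.all(dense_outputs.grad == 1.0)      # the dense path carries all of it
```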
@@ -23,48 +24,27 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
         router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
     """
     def forward(self, hidden_states: torch.Tensor):
-        """
-        Input:
-            hidden_states: Tensor, shape (batch_size, sequence_length, hidden_dim)
-        Output:
-            final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
-            router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
-        Approach:
-        1. Flatten the input to (B*seq_len, hidden_dim), pass it through self.gate to get
-           router_logits, and compute routing weights over all experts (after softmax).
-        2. Take the top-k of the routing weights to obtain routing_weights_topk and
-           selected_experts; normalize the top-k probabilities if the config requires it.
-        3. Sparse part: compute each token's output only for its top-k experts and
-           accumulate into sparse_output (keeps the original computation flow while also
-           recording the activated experts' actual outputs).
-        4. Dense-estimation part: first compute every expert's output for every token
-           (all_expert_outputs), then call estimate_dense_output per token to obtain the
-           dense output (dense_estimated).
-        5. Straight-through trick: the forward output is sparse_output, but the gradient
-           comes from dense_estimated.
-        6. Finally reshape to (batch_size, sequence_length, hidden_dim) and return
-           final_output and router_logits.
-        """
-        # determine the shape of hidden_states
         batch_size, seq_length, hidden_dim = hidden_states.shape
         flat_hidden = hidden_states.view(-1, hidden_dim)  # (B*seq_len, hidden_dim)
 
-        # compute the router logits and the routing weights over all experts
         router_logits = self.gate(flat_hidden)  # (B*seq_len, num_experts)
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)  # (B*seq_len, num_experts)
 
-        # top-k selection
         routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
         if self.norm_topk_prob:
             routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
         routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
 
         # ---------- sparse computation ----------
-        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
+        sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim),
+                                    dtype=flat_hidden.dtype, device=flat_hidden.device)
+        # tensor storage for each token's output per expert: (B*seq_len, num_experts, hidden_dim)
+        activated_outputs_tensor = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
+                                               dtype=flat_hidden.dtype, device=flat_hidden.device)
+        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
         expert_mask = expert_mask.permute(2, 1, 0)  # (num_experts, top_k, B*seq_len)
 
-        for expert_idx in range(self.num_experts):
+        for expert_idx in tqdm(range(self.num_experts), desc="modified version: expert loop"):
             expert_layer = self.experts[expert_idx]
             idx, top_x = torch.where(expert_mask[expert_idx])
             if top_x.numel() > 0:
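As a quick sanity check on the gating block above, here is the same softmax → top-k → renormalize sequence on hypothetical logits (3 tokens, 4 experts, with top_k = 2 and norm_topk_prob assumed true):

```python
import torch
import torch.nn.functional as F

router_logits = torch.tensor([[2.0, 1.0, 0.5, 0.1],
                              [0.1, 3.0, 0.2, 2.5],
                              [0.3, 0.2, 1.5, 0.9]])
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

routing_weights_topk, selected_experts = torch.topk(routing_weights, 2, dim=-1)
# norm_topk_prob=True: the kept probabilities are rescaled to sum to 1 per token
routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)

print(selected_experts)                  # tensor([[0, 1], [1, 3], [2, 3]])
print(routing_weights_topk.sum(dim=-1))  # tensor([1., 1., 1.])
```

Only the renormalized top-k weights enter sparse_output; the full softmax (routing_weights) is reused later as the gate probabilities for the dense estimate.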
@@ -73,75 +53,77 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
                 weight = routing_weights_topk[top_x, idx].unsqueeze(-1)  # (n, 1)
                 weighted_output = current_output * weight
                 sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
-                activated_outputs[token_idx][expert_idx] = current_output[pos]
+                # store directly into the tensor: the activated tokens' outputs for this expert
+                activated_outputs_tensor[top_x, expert_idx, :] = current_output
         # ---------- end of sparse computation ----------
 
-        # ---------- dense estimation ----------
-        all_expert_outputs = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
-                                         dtype=flat_hidden.dtype, device=flat_hidden.device)
-        # fill in the activated experts' outputs
-        for i in range(flat_hidden.size(0)):
-            for expert_idx in activated_outputs[i].keys():
-                all_expert_outputs[i, expert_idx] = activated_outputs[i][expert_idx]
-        # convert selected_experts into a list of activated experts per token
-        all_routing = selected_experts.tolist()  # length (B*seq_len)
+        # ---------- dense estimation (vectorized; activated experts use their real outputs) ----------
+        all_expert_outputs = activated_outputs_tensor  # (B*seq_len, num_experts, hidden_dim)
+        all_routing = selected_experts.tolist()  # list of activated experts per token
 
+        N_tokens = flat_hidden.size(0)
+        num_experts = self.num_experts
+
+        # turn selected_experts into a one-hot binary matrix R: (N_tokens, num_experts)
+        R = F.one_hot(selected_experts, num_classes=num_experts).float()  # (N_tokens, top_k, num_experts)
+        R = R.sum(dim=1)  # (N_tokens, num_experts); entries > 0 mark activated experts
+
+        # shared-activation matrix S: (N_tokens, N_tokens)
+        S = torch.matmul(R, R.t())  # S[i, j] > 0 means tokens i and j share at least one activated expert
+        S = S * (1 - torch.eye(N_tokens, device=S.device))  # remove self-pairs
+
+        # build the candidate mask M: (N_tokens, N_tokens, num_experts)
+        # M[i, j, e] = 1 means token j activated expert e and token i shares at least one activated expert with token j
+        R_expanded = R.unsqueeze(0).expand(N_tokens, -1, -1)  # (N_tokens, N_tokens, num_experts)
+        S_expanded = S.unsqueeze(-1)  # (N_tokens, N_tokens, 1)
+        candidate_mask = ((R_expanded > 0) & (S_expanded > 0)).float()  # (N_tokens, N_tokens, num_experts)
+
+        # for numerical safety, exclude the token itself (zero the diagonal)
+        candidate_mask = candidate_mask * (1 - torch.eye(N_tokens, device=candidate_mask.device).unsqueeze(-1))
+
+        # expand the mask and all_expert_outputs for batched aggregation
+        # all_expert_outputs: (N_tokens, num_experts, hidden_dim)
+        candidate_mask_exp = candidate_mask.unsqueeze(-1)  # (N_tokens, N_tokens, num_experts, 1)
+        all_expert_outputs_exp = all_expert_outputs.unsqueeze(0)  # (1, N_tokens, num_experts, hidden_dim)
+
+        # for each token i and expert e, aggregate the candidate tokens' outputs
+        sum_outputs = (candidate_mask_exp * all_expert_outputs_exp).sum(dim=1)  # (N_tokens, num_experts, hidden_dim)
+        count_outputs = candidate_mask.sum(dim=1).unsqueeze(-1)  # (N_tokens, num_experts, 1)
+        estimated_dense_all = torch.where(count_outputs > 0, sum_outputs / count_outputs,
+                                          torch.zeros_like(sum_outputs))  # (N_tokens, num_experts, hidden_dim)
+
+        # for activated experts, use the current token's own output directly;
+        # R > 0 marks activation, expanded to (N_tokens, num_experts, 1) to align with activated_outputs_tensor
+        activated_mask = (R > 0).unsqueeze(-1)
+        estimated_dense_all = torch.where(activated_mask, activated_outputs_tensor, estimated_dense_all)
+
+        # aggregate all expert outputs weighted by the gate probabilities
+        gate_prob_exp = routing_weights.to(estimated_dense_all.dtype).unsqueeze(-1)  # (N_tokens, num_experts, 1)
+        dense_outputs = (gate_prob_exp * estimated_dense_all).sum(dim=1)  # (N_tokens, hidden_dim)
+        # ---------- end of dense estimation (vectorized) ----------
+
         final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
         final_output = final_flat.view(batch_size, seq_length, hidden_dim)
         return final_output, router_logits
 
     def estimate_dense_output(self, token_idx, activated, gate_prob, activated_outputs, all_routing, all_expert_outputs):
-        """
-        Estimate the dense output for the current token from information available in the mini-batch.
-        Args:
-            token_idx: index of the current token (scalar)
-            activated: list of experts activated by the current token, e.g. [1, 3]
-            gate_prob: routing weights of the current token, shape (num_experts,)
-            activated_outputs: dict of the token's actual outputs for its activated experts, each of shape (hidden_dim,)
-            all_routing: list of activated-expert lists, one per token (length N, each element a list)
-            all_expert_outputs: Tensor, (N, num_experts, hidden_dim)
-        Returns:
-            estimated_dense: Tensor, (hidden_dim,)
-        """
         num_experts = gate_prob.size(0)
         dense_parts = {}
+        # for activated experts, use the corresponding rows of the tensor directly
         for idx in activated:
             dense_parts[idx] = activated_outputs[idx]
-        # for non-activated experts, estimate from the outputs of other tokens in the mini-batch
         non_activated = [i for i in range(num_experts) if i not in activated]
-        for i in non_activated:
+        for i in tqdm(non_activated, desc=f"modified version: token {token_idx} inactive-expert estimation"):
             indices = []
             for idx, r_dec in enumerate(all_routing):
                 if (i in r_dec) and (len(set(r_dec) & set(activated)) > 0):
                     indices.append(idx)
             if indices:
                 selected_outputs = all_expert_outputs[indices, i, :]  # (n, hidden_dim)
-                # average only the non-zero entries
                 mask = (selected_outputs.sum(dim=-1) != 0).to(selected_outputs.dtype).unsqueeze(-1)
                 if mask.sum() > 0:
                     estimated = (selected_outputs * mask).sum(dim=0) / mask.sum()
                 else:
-                    # if everything is zero, return a zero vector
                     estimated = torch.zeros_like(selected_outputs[0])
             else:
                 all_outputs = all_expert_outputs[:, i, :]
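The R/S/candidate_mask construction is the heart of the vectorized rewrite, so a small standalone check of its logic may help (sizes are arbitrary; this mirrors the tensor algebra above but is not part of the repository):

```python
import torch
import torch.nn.functional as F

# 4 tokens, 3 experts, top_k = 1 for readability.
selected_experts = torch.tensor([[0], [0], [1], [2]])
N_tokens, num_experts = 4, 3

R = F.one_hot(selected_experts, num_classes=num_experts).float().sum(dim=1)  # (4, 3)
S = R @ R.t()                          # S[i, j] > 0: tokens i and j co-activate an expert
S = S * (1 - torch.eye(N_tokens))      # a token is not its own candidate

candidate_mask = ((R.unsqueeze(0).expand(N_tokens, -1, -1) > 0)
                  & (S.unsqueeze(-1) > 0)).float()
candidate_mask = candidate_mask * (1 - torch.eye(N_tokens).unsqueeze(-1))

# Tokens 0 and 1 both picked expert 0, so each is the other's candidate for it;
# tokens 2 and 3 share no expert with anyone and end up with no candidates.
print(candidate_mask[0, :, 0])  # tensor([0., 1., 0., 0.])
print(candidate_mask[2].sum())  # tensor(0.)
```

For such isolated tokens count_outputs is zero and the torch.where above falls back to a zero estimate, matching the zero-vector branch of the per-token estimate_dense_output loop. Note that candidate_mask has shape (N_tokens, N_tokens, num_experts), so the vectorization trades memory quadratic in the token count for the removal of the Python loop.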
@@ -149,16 +131,13 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
                 if mask.sum() > 0:
                     estimated = (all_outputs * mask).sum(dim=0) / mask.sum()
                 else:
-                    # if everything is zero, return a zero vector
                     estimated = torch.zeros_like(all_outputs[0])
             dense_parts[i] = estimated
-        # weighted sum of the expert outputs by gate_prob
         estimated_dense = 0
         for i in range(num_experts):
             estimated_dense += gate_prob[i] * dense_parts[i]
         return estimated_dense
 
-
 class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
     """
     A custom Olmoe ForCausalLM model that replaces the original MoE modules with the new DenseBackwardOlmoeSparseMoeBlock,
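The truncated docstring says the wrapper swaps the original MoE modules for DenseBackwardOlmoeSparseMoeBlock. The constructor is not shown in this diff; a hypothetical sketch of such a swap, assuming the layer layout of the upstream OLMoE implementation (each decoder layer holding its sparse MoE block as layer.mlp), might look like:

```python
# Hypothetical sketch only: the real constructor is not part of this diff,
# and the attribute path `layer.mlp` is assumed from the upstream OLMoE code.
class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
    config_class = DenseBackwardOLMoEConfig

    def __init__(self, config):
        super().__init__(config)
        for layer in self.model.layers:
            if isinstance(layer.mlp, OlmoeSparseMoeBlock):
                dense_block = DenseBackwardOlmoeSparseMoeBlock(config)
                dense_block.load_state_dict(layer.mlp.state_dict())  # keep the pretrained weights
                layer.mlp = dense_block
```

Because DenseBackwardOlmoeSparseMoeBlock subclasses OlmoeSparseMoeBlock, the state dicts line up and only the forward/backward behavior changes.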