Update modeling_densebackward_olmoe0125.py
Browse files
modeling_densebackward_olmoe0125.py
CHANGED
@@ -26,118 +26,53 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
|
|
26 |
def forward(self, hidden_states: torch.Tensor):
|
27 |
batch_size, seq_length, hidden_dim = hidden_states.shape
|
28 |
flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
|
|
|
29 |
|
|
|
30 |
router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
|
31 |
-
routing_weights = F.softmax(router_logits, dim=1, dtype=
|
32 |
|
|
|
33 |
routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
34 |
if self.norm_topk_prob:
|
35 |
routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
|
36 |
routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
|
37 |
|
38 |
-
# ----------
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
activated_outputs_tensor = torch.zeros((flat_hidden.size(0), self.num_experts, hidden_dim),
|
43 |
-
dtype=flat_hidden.dtype, device=flat_hidden.device)
|
44 |
-
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
|
45 |
-
expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
|
46 |
-
|
47 |
for expert_idx in range(self.num_experts):
|
48 |
expert_layer = self.experts[expert_idx]
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
current_output = expert_layer(current_state) # (n, hidden_dim)
|
53 |
-
weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
|
54 |
-
weighted_output = current_output * weight
|
55 |
-
sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
|
56 |
-
# 直接存入 tensor:激活 token 对当前专家的输出
|
57 |
-
activated_outputs_tensor[top_x, expert_idx, :] = current_output
|
58 |
-
# ---------- 稀疏计算结束 ----------
|
59 |
-
|
60 |
-
# ---------- Dense估计部分 (向量化版本,激活专家直接使用输出) ----------
|
61 |
-
all_expert_outputs = activated_outputs_tensor # (B*seq_len, num_experts, hidden_dim)
|
62 |
-
all_routing = selected_experts.tolist() # list,每个 token 的激活专家列表
|
63 |
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
#
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
#
|
82 |
-
|
83 |
-
|
84 |
-
# 扩展 mask 和 all_expert_outputs 以便批量聚合
|
85 |
-
# all_expert_outputs: (N_tokens, num_experts, hidden_dim)
|
86 |
-
candidate_mask_exp = candidate_mask.unsqueeze(-1) # (N_tokens, N_tokens, num_experts, 1)
|
87 |
-
all_expert_outputs_exp = all_expert_outputs.unsqueeze(0) # (1, N_tokens, num_experts, hidden_dim)
|
88 |
-
|
89 |
-
# 对每个 token i 和专家 e,聚合候选 token 的输出
|
90 |
-
sum_outputs = (candidate_mask_exp * all_expert_outputs_exp).sum(dim=1) # (N_tokens, num_experts, hidden_dim)
|
91 |
-
count_outputs = candidate_mask.sum(dim=1).unsqueeze(-1) # (N_tokens, num_experts, 1)
|
92 |
-
estimated_dense_all = torch.where(count_outputs > 0, sum_outputs / (count_outputs+1),
|
93 |
-
torch.zeros_like(sum_outputs)) # (N_tokens, num_experts, hidden_dim)
|
94 |
-
|
95 |
-
# 对于激活的专家,直接使用当前 token 的输出
|
96 |
-
# R > 0 表示激活,扩展为 (N_tokens, num_experts, 1) 与 activated_outputs_tensor 对齐
|
97 |
-
activated_mask = (R > 0).unsqueeze(-1)
|
98 |
-
estimated_dense_all = torch.where(activated_mask, activated_outputs_tensor, estimated_dense_all)
|
99 |
-
|
100 |
-
# 利用 gate_prob 加权聚合所有专家输出
|
101 |
-
gate_prob_exp = routing_weights.to(estimated_dense_all.dtype).unsqueeze(-1) # (N_tokens, num_experts, 1)
|
102 |
-
dense_outputs = (gate_prob_exp * estimated_dense_all).sum(dim=1) # (N_tokens, hidden_dim)
|
103 |
-
# ---------- Dense估计结束 (向量化版本) ----------
|
104 |
-
|
105 |
final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
|
106 |
final_output = final_flat.view(batch_size, seq_length, hidden_dim)
|
|
|
107 |
return final_output, router_logits
|
108 |
|
109 |
-
def estimate_dense_output(self, token_idx, activated, gate_prob, activated_outputs, all_routing, all_expert_outputs):
|
110 |
-
num_experts = gate_prob.size(0)
|
111 |
-
dense_parts = {}
|
112 |
-
# 对于激活的专家,直接使用 tensor 的对应行
|
113 |
-
for idx in activated:
|
114 |
-
dense_parts[idx] = activated_outputs[idx]
|
115 |
-
non_activated = [i for i in range(num_experts) if i not in activated]
|
116 |
-
for i in non_activated:
|
117 |
-
indices = []
|
118 |
-
for idx, r_dec in enumerate(all_routing):
|
119 |
-
if (i in r_dec) and (len(set(r_dec) & set(activated)) > 0):
|
120 |
-
indices.append(idx)
|
121 |
-
if indices:
|
122 |
-
selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
|
123 |
-
mask = (selected_outputs.sum(dim=-1) != 0).to(selected_outputs.dtype).unsqueeze(-1)
|
124 |
-
if mask.sum() > 0:
|
125 |
-
estimated = (selected_outputs * mask).sum(dim=0) / mask.sum()
|
126 |
-
else:
|
127 |
-
estimated = torch.zeros_like(selected_outputs[0])
|
128 |
-
else:
|
129 |
-
all_outputs = all_expert_outputs[:, i, :]
|
130 |
-
mask = (all_outputs.sum(dim=-1) != 0).to(all_outputs.dtype).unsqueeze(-1)
|
131 |
-
if mask.sum() > 0:
|
132 |
-
estimated = (all_outputs * mask).sum(dim=0) / mask.sum()
|
133 |
-
else:
|
134 |
-
estimated = torch.zeros_like(all_outputs[0])
|
135 |
-
dense_parts[i] = estimated
|
136 |
-
estimated_dense = 0
|
137 |
-
for i in range(num_experts):
|
138 |
-
estimated_dense += gate_prob[i] * dense_parts[i]
|
139 |
-
return estimated_dense
|
140 |
-
|
141 |
class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
|
142 |
"""
|
143 |
自定义的 Olmoe ForCausalLM 模型,使用新的 DenseBackwardOlmoeSparseMoeBlock 替换原版的 MoE 模块,
|
|
|
26 |
def forward(self, hidden_states: torch.Tensor):
|
27 |
batch_size, seq_length, hidden_dim = hidden_states.shape
|
28 |
flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
|
29 |
+
N_tokens = flat_hidden.size(0)
|
30 |
|
31 |
+
# 计算路由逻辑
|
32 |
router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
|
33 |
+
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
|
34 |
|
35 |
+
# 选择top-k专家
|
36 |
routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
|
37 |
if self.norm_topk_prob:
|
38 |
routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
|
39 |
routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
|
40 |
|
41 |
+
# ---------- 真实计算所有专家输出(密集计算)----------
|
42 |
+
all_expert_outputs = torch.zeros((N_tokens, self.num_experts, hidden_dim),
|
43 |
+
dtype=flat_hidden.dtype, device=flat_hidden.device)
|
44 |
+
|
|
|
|
|
|
|
|
|
|
|
45 |
for expert_idx in range(self.num_experts):
|
46 |
expert_layer = self.experts[expert_idx]
|
47 |
+
# 对所有token都计算当前专家的输出
|
48 |
+
expert_output = expert_layer(flat_hidden) # (N_tokens, hidden_dim)
|
49 |
+
all_expert_outputs[:, expert_idx, :] = expert_output
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
50 |
|
51 |
+
# ---------- 提取激活专家输出(稀疏前向)----------
|
52 |
+
# 计算稀疏输出
|
53 |
+
sparse_output = torch.zeros((N_tokens, hidden_dim),
|
54 |
+
dtype=flat_hidden.dtype, device=flat_hidden.device)
|
55 |
+
|
56 |
+
# 为每个token,提取并加权其激活专家的输出
|
57 |
+
for token_idx in range(N_tokens):
|
58 |
+
for k in range(self.top_k):
|
59 |
+
expert_idx = selected_experts[token_idx, k].item()
|
60 |
+
weight = routing_weights_topk[token_idx, k]
|
61 |
+
sparse_output[token_idx] += all_expert_outputs[token_idx, expert_idx] * weight
|
62 |
+
|
63 |
+
# ---------- 密集计算聚合(用于反向传播)----------
|
64 |
+
# 使用所有专家的输出和路由权重计算密集输出
|
65 |
+
routing_weights_expanded = routing_weights.unsqueeze(-1) # (N_tokens, num_experts, 1)
|
66 |
+
dense_outputs = (all_expert_outputs * routing_weights_expanded).sum(dim=1) # (N_tokens, hidden_dim)
|
67 |
+
|
68 |
+
# ---------- 组合稀疏前向和密集反向 ----------
|
69 |
+
# sparse_output.detach()保留稀疏前向计算图
|
70 |
+
# (dense_outputs - dense_outputs.detach())只保留密集反向梯度
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
|
72 |
final_output = final_flat.view(batch_size, seq_length, hidden_dim)
|
73 |
+
|
74 |
return final_output, router_logits
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
|
77 |
"""
|
78 |
自定义的 Olmoe ForCausalLM 模型,使用新的 DenseBackwardOlmoeSparseMoeBlock 替换原版的 MoE 模块,
|