autoprogrammer commited on
Commit
370ac60
·
verified ·
1 Parent(s): 56d7a8b

Update modeling_densebackward_olmoe.py

Browse files
Files changed (1) hide show
  1. modeling_densebackward_olmoe.py +80 -78
modeling_densebackward_olmoe.py CHANGED
@@ -10,42 +10,11 @@ from .configuration_densebackward_olmoe import DenseBackwardOLMoEConfig
10
 
11
 
12
  class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
13
- """
14
- 继承自官方 OlmoeSparseMoeBlock,实现 dense backward 功能:
15
- 前向输出依旧保持与官方相同(即稀疏计算结果),
16
- 但在反向传播时,通过直通梯度让 dense 计算的梯度传递回来,
17
- dense 输出通过对每个专家在所有 token 上进行计算,并利用全 routing 权重加权获得。
18
-
19
- 输入:
20
- hidden_states: Tensor, shape (batch_size, sequence_length, hidden_dim)
21
- 输出:
22
- final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
23
- router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
24
- """
25
  def forward(self, hidden_states: torch.Tensor):
26
- """
27
- 输入:
28
- hidden_states: Tensor, shape (batch_size, sequence_length, hidden_dim)
29
- 输出:
30
- final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
31
- router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
32
- 实现思路:
33
- 1. 将输入展平为 (B*seq_len, hidden_dim),通过 self.gate 得到 router_logits,
34
- 并计算全专家的 routing 权重(softmax 后)。
35
- 2. 对 routing 权重取 top-k,得到 routing_weights_topk 与 selected_experts;
36
- 如配置要求,归一化 top-k 概率。
37
- 3. 稀疏计算部分:仅计算每个 token 对于 top-k 专家的输出,
38
- 并累加得到 sparse_output(保留原版计算流程,同时记录激活专家的实际输出)。
39
- 4. Dense 估计部分:先计算所有专家对所有 token 的输出(all_expert_outputs),
40
- 再逐 token 调用 estimate_dense_output 得到 dense 输出(dense_estimated)。
41
- 5. 使用直通梯度技巧:前向输出用 sparse_output,但梯度来源于 dense_estimated。
42
- 6. 最后 reshape 为 (batch_size, sequence_length, hidden_dim) 并返回 final_output 及 router_logits.
43
- """
44
- #determine the shape of hidden_states
45
  batch_size, seq_length, hidden_dim = hidden_states.shape
46
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
47
 
48
- # 计算路由 logits 和全专家 routing 权重
49
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
50
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
51
 
@@ -56,14 +25,20 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
56
  routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
57
 
58
  # ---------- 稀疏计算部分 ----------
59
- # 初始化稀疏输出,shape: (B*seq_len, hidden_dim)
60
  sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
61
- # 用于记录每个 token 对激活专家的实际输出
62
- activated_outputs = [{} for _ in range(flat_hidden.size(0))]
63
- # one-hot 编码 top-k 专家,shape: (B*seq_len, top_k, num_experts)
 
 
 
 
 
64
  expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
65
  expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
66
 
 
67
  for expert_idx in range(self.num_experts):
68
  expert_layer = self.experts[expert_idx]
69
  idx, top_x = torch.where(expert_mask[expert_idx])
@@ -73,71 +48,98 @@ class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
73
  weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
74
  weighted_output = current_output * weight
75
  sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
76
- # 保存当前 token 对该专家的实际输出
 
 
 
 
77
  for pos, token_idx in enumerate(top_x.tolist()):
78
- activated_outputs[token_idx][expert_idx] = current_output[pos]
 
 
 
 
 
 
79
  # ---------- 稀疏计算结束 ----------
80
 
81
  # ---------- Dense估计部分 ----------
82
- # 计算所有专家对所有 token 的 dense 输出,shape: (B*seq_len, num_experts, hidden_dim)
83
- all_expert_outputs = torch.stack([expert(flat_hidden) for expert in self.experts], dim=1)
84
- # 将 selected_experts 转换为 list,每个 token 的激活专家列表
85
  all_routing = selected_experts.tolist() # 长度为 (B*seq_len)
86
 
 
87
  dense_outputs = []
88
- for i in range(flat_hidden.size(0)):
89
- dense_est = self.estimate_dense_output(
90
- token_idx=i,
91
- activated=all_routing[i], # 当前 token 激活的专家列表,例如 [a, b]
92
- gate_prob=routing_weights[i], # 当前 token 的完整 routing 权重 (num_experts,)
93
- activated_outputs=activated_outputs[i], # 当前 token 对激活专家的实际输出
94
- all_routing=all_routing, # 全 batch 每个 token 的激活专家列表(list of lists)
95
- all_expert_outputs=all_expert_outputs # (B*seq_len, num_experts, hidden_dim)
 
 
 
 
96
  )
97
  dense_outputs.append(dense_est.unsqueeze(0))
 
98
  dense_outputs = torch.cat(dense_outputs, dim=0) # (B*seq_len, hidden_dim)
99
  # ---------- Dense估计结束 ----------
100
 
101
- # 使用直通梯度:前向输出用稀疏结果,但反向传播时梯度来源于 dense 估计
102
  final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
103
  final_output = final_flat.view(batch_size, seq_length, hidden_dim)
104
  return final_output, router_logits
105
 
106
- def estimate_dense_output(self, token_idx, activated, gate_prob, activated_outputs, all_routing, all_expert_outputs):
 
107
  """
108
- 对于当前 token,根据 mini-batch 中的信息估计 dense 输出。
109
- 参数:
110
- token_idx: 当前 token 的索引(标量)
111
- activated: 当前 token 激活的专家列表,例如 [1, 3]
112
- gate_prob: 当前 token 的 routing 权重,形状 (num_experts,)
113
- activated_outputs: dict,当前 token 对激活专家的实际输出,形状 (hidden_dim,)
114
- all_routing: list,每个 token 的激活专家列表(长度为 N,每个元素为 list)
115
- all_expert_outputs: Tensor, (N, num_experts, hidden_dim)
116
- 返回:
117
- estimated_dense: Tensor, (hidden_dim,)
118
  """
119
  num_experts = gate_prob.size(0)
120
  dense_parts = {}
121
- # 对于激活的专家,直接使用其实际输出
122
- for idx in activated:
123
- dense_parts[idx] = activated_outputs[idx]
124
- # 对于未激活的专家,使用 mini-batch 中其他 token 的输出估计
 
 
 
125
  non_activated = [i for i in range(num_experts) if i not in activated]
126
- for i in non_activated:
127
- indices = []
128
- for idx, r_dec in enumerate(all_routing):
129
- if (i in r_dec) and (len(set(r_dec) & set(activated)) > 0):
130
- indices.append(idx)
131
- if indices:
132
- selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
133
- estimated = selected_outputs.mean(dim=0)
 
 
 
 
 
 
 
 
 
 
 
 
134
  else:
135
- estimated = all_expert_outputs[:, i, :].mean(dim=0)
136
- dense_parts[i] = estimated
137
- # gate_prob 加权求和各专家输出
 
 
 
 
138
  estimated_dense = 0
139
- for i in range(num_experts):
140
- estimated_dense += gate_prob[i] * dense_parts[i]
 
 
141
  return estimated_dense
142
 
143
 
 
10
 
11
 
12
  class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
 
 
 
 
 
 
 
 
 
 
 
 
13
  def forward(self, hidden_states: torch.Tensor):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  batch_size, seq_length, hidden_dim = hidden_states.shape
15
  flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
16
 
17
+ # 计算路由 logits routing 权重
18
  router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
19
  routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
20
 
 
25
  routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
26
 
27
  # ---------- 稀疏计算部分 ----------
28
+ # 初始化稀疏输出
29
  sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
30
+
31
+ # 存储所有激活信息的数据结构
32
+ num_tokens = flat_hidden.size(0)
33
+ all_activated_outputs = {} # {expert_idx: {token_idx: output_tensor}}
34
+ all_routing_indices = {} # {expert_idx: [token_indices]}
35
+ token_activated_experts = {} # {token_idx: [activated_expert_indices]}
36
+
37
+ # one-hot 编码 top-k 专家
38
  expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
39
  expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
40
 
41
+ # 稀疏计算,同时记录激活情况
42
  for expert_idx in range(self.num_experts):
43
  expert_layer = self.experts[expert_idx]
44
  idx, top_x = torch.where(expert_mask[expert_idx])
 
48
  weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
49
  weighted_output = current_output * weight
50
  sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
51
+
52
+ # 记录该专家激活的token和对应输出
53
+ all_activated_outputs[expert_idx] = {}
54
+ all_routing_indices[expert_idx] = top_x.tolist()
55
+
56
  for pos, token_idx in enumerate(top_x.tolist()):
57
+ # 记录该专家对该token的输出
58
+ all_activated_outputs[expert_idx][token_idx] = current_output[pos]
59
+
60
+ # 记录该token激活的专家
61
+ if token_idx not in token_activated_experts:
62
+ token_activated_experts[token_idx] = []
63
+ token_activated_experts[token_idx].append(expert_idx)
64
  # ---------- 稀疏计算结束 ----------
65
 
66
  # ---------- Dense估计部分 ----------
67
+ # 将activated_experts 转换为list格式,与路由权重匹配
 
 
68
  all_routing = selected_experts.tolist() # 长度为 (B*seq_len)
69
 
70
+ # 使用已激活信息估计dense输出
71
  dense_outputs = []
72
+ for token_idx in range(num_tokens):
73
+ # 获取当前token的激活专家列表
74
+ activated = all_routing[token_idx] if token_idx in token_activated_experts else []
75
+
76
+ # 估计dense输出(只使用已经计算过的专家输出)
77
+ dense_est = self.estimate_dense_output_efficient(
78
+ token_idx=token_idx,
79
+ activated=activated,
80
+ gate_prob=routing_weights[token_idx],
81
+ all_activated_outputs=all_activated_outputs,
82
+ all_routing_indices=all_routing_indices,
83
+ token_activated_experts=token_activated_experts
84
  )
85
  dense_outputs.append(dense_est.unsqueeze(0))
86
+
87
  dense_outputs = torch.cat(dense_outputs, dim=0) # (B*seq_len, hidden_dim)
88
  # ---------- Dense估计结束 ----------
89
 
90
+ # 使用直通梯度技巧
91
  final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
92
  final_output = final_flat.view(batch_size, seq_length, hidden_dim)
93
  return final_output, router_logits
94
 
95
+ def estimate_dense_output_efficient(self, token_idx, activated, gate_prob,
96
+ all_activated_outputs, all_routing_indices, token_activated_experts):
97
  """
98
+ 优化版本的dense输出估计,只使用已计算的专家输出
 
 
 
 
 
 
 
 
 
99
  """
100
  num_experts = gate_prob.size(0)
101
  dense_parts = {}
102
+
103
+ # 对于激活的专家,直接使用其输出
104
+ for expert_idx in activated:
105
+ if expert_idx in all_activated_outputs and token_idx in all_activated_outputs[expert_idx]:
106
+ dense_parts[expert_idx] = all_activated_outputs[expert_idx][token_idx]
107
+
108
+ # 对于未激活的专家,使用其他token的激活输出估计
109
  non_activated = [i for i in range(num_experts) if i not in activated]
110
+ for expert_idx in non_activated:
111
+ # 如果该专家没有被任何token激活,跳过
112
+ if expert_idx not in all_routing_indices or not all_routing_indices[expert_idx]:
113
+ # 使用零向量或平均值作为估计
114
+ dense_parts[expert_idx] = torch.zeros_like(next(iter(dense_parts.values()))) if dense_parts else 0
115
+ continue
116
+
117
+ # 找出激活了该专家的token,并且这些token也激活了当前token激活的某些专家
118
+ candidate_tokens = []
119
+ for other_token in all_routing_indices[expert_idx]:
120
+ # 检查other_token是否与当前token共享某些激活专家
121
+ if other_token in token_activated_experts:
122
+ common_experts = set(activated) & set(token_activated_experts[other_token])
123
+ if common_experts:
124
+ candidate_tokens.append(other_token)
125
+
126
+ # 如果找到了候选token,使用它们的输出平均值
127
+ if candidate_tokens:
128
+ expert_outputs = [all_activated_outputs[expert_idx][t] for t in candidate_tokens]
129
+ estimated = torch.stack(expert_outputs).mean(dim=0)
130
  else:
131
+ # 找不到合适的候选,使用所有激活了该专家的token
132
+ expert_outputs = [all_activated_outputs[expert_idx][t] for t in all_routing_indices[expert_idx]]
133
+ estimated = torch.stack(expert_outputs).mean(dim=0)
134
+
135
+ dense_parts[expert_idx] = estimated
136
+
137
+ # 按路由权重加权求和
138
  estimated_dense = 0
139
+ for expert_idx in range(num_experts):
140
+ if expert_idx in dense_parts:
141
+ estimated_dense += gate_prob[expert_idx] * dense_parts[expert_idx]
142
+
143
  return estimated_dense
144
 
145