File size: 11,120 Bytes
94724ad |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 |
# my_custom_olmoe/modeling_custom.py
import torch
import torch.nn as nn
import torch.nn.functional as F
# 导入官方实现(注意根据你的 transformers 版本调整导入路径)
from transformers.models.olmoe.modeling_olmoe import OlmoeForCausalLM, OlmoeSparseMoeBlock, OlmoeMLP
from configuration_custom import DenseBackwardOLMoEConfig
class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
"""
继承自官方 OlmoeSparseMoeBlock,实现 dense backward 功能:
前向输出依旧保持与官方相同(即稀疏计算结果),
但在反向传播时,通过直通梯度让 dense 计算的梯度传递回来,
dense 输出通过对每个专家在所有 token 上进行计算,并利用全 routing 权重加权获得。
输入:
hidden_states: Tensor, shape (batch_size, sequence_length, hidden_dim)
输出:
final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
"""
def forward(self, hidden_states: torch.Tensor):
"""
输入:
hidden_states: Tensor, shape (batch_size, sequence_length, hidden_dim)
输出:
final_output: Tensor, shape (batch_size, sequence_length, hidden_dim)
router_logits: Tensor, shape (batch_size * sequence_length, num_experts)
实现思路:
1. 将输入展平为 (B*seq_len, hidden_dim),通过 self.gate 得到 router_logits,
并计算全专家的 routing 权重(softmax 后)。
2. 对 routing 权重取 top-k,得到 routing_weights_topk 与 selected_experts;
如配置要求,归一化 top-k 概率。
3. 稀疏计算部分:仅计算每个 token 对于 top-k 专家的输出,
并累加得到 sparse_output(保留原版计算流程,同时记录激活专家的实际输出)。
4. Dense 估计部分:先计算所有专家对所有 token 的输出(all_expert_outputs),
再逐 token 调用 estimate_dense_output 得到 dense 输出(dense_estimated)。
5. 使用直通梯度技巧:前向输出用 sparse_output,但梯度来源于 dense_estimated。
6. 最后 reshape 为 (batch_size, sequence_length, hidden_dim) 并返回 final_output 及 router_logits.
"""
#determine the shape of hidden_states
batch_size, seq_length, hidden_dim = hidden_states.shape
flat_hidden = hidden_states.view(-1, hidden_dim) # (B*seq_len, hidden_dim)
# 计算路由 logits 和全专家 routing 权重
router_logits = self.gate(flat_hidden) # (B*seq_len, num_experts)
routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float) # (B*seq_len, num_experts)
# Top-k 选择
routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
if self.norm_topk_prob:
routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)
# ---------- 稀疏计算部分 ----------
# 初始化稀疏输出,shape: (B*seq_len, hidden_dim)
sparse_output = torch.zeros((flat_hidden.size(0), hidden_dim), dtype=flat_hidden.dtype, device=flat_hidden.device)
# 用于记录每个 token 对激活专家的实际输出
activated_outputs = [{} for _ in range(flat_hidden.size(0))]
# one-hot 编码 top-k 专家,shape: (B*seq_len, top_k, num_experts)
expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts) # (B*seq_len, top_k, num_experts)
expert_mask = expert_mask.permute(2, 1, 0) # (num_experts, top_k, B*seq_len)
for expert_idx in range(self.num_experts):
expert_layer = self.experts[expert_idx]
idx, top_x = torch.where(expert_mask[expert_idx])
if top_x.numel() > 0:
current_state = flat_hidden[top_x] # (n, hidden_dim)
current_output = expert_layer(current_state) # (n, hidden_dim)
weight = routing_weights_topk[top_x, idx].unsqueeze(-1) # (n, 1)
weighted_output = current_output * weight
sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))
# 保存当前 token 对该专家的实际输出
for pos, token_idx in enumerate(top_x.tolist()):
activated_outputs[token_idx][expert_idx] = current_output[pos]
# ---------- 稀疏计算结束 ----------
# ---------- Dense估计部分 ----------
# 计算所有专家对所有 token 的 dense 输出,shape: (B*seq_len, num_experts, hidden_dim)
all_expert_outputs = torch.stack([expert(flat_hidden) for expert in self.experts], dim=1)
# 将 selected_experts 转换为 list,每个 token 的激活专家列表
all_routing = selected_experts.tolist() # 长度为 (B*seq_len)
dense_outputs = []
for i in range(flat_hidden.size(0)):
dense_est = self.estimate_dense_output(
token_idx=i,
activated=all_routing[i], # 当前 token 激活的专家列表,例如 [a, b]
gate_prob=routing_weights[i], # 当前 token 的完整 routing 权重 (num_experts,)
activated_outputs=activated_outputs[i], # 当前 token 对激活专家的实际输出
all_routing=all_routing, # 全 batch 每个 token 的激活专家列表(list of lists)
all_expert_outputs=all_expert_outputs # (B*seq_len, num_experts, hidden_dim)
)
dense_outputs.append(dense_est.unsqueeze(0))
dense_outputs = torch.cat(dense_outputs, dim=0) # (B*seq_len, hidden_dim)
# ---------- Dense估计结束 ----------
# 使用直通梯度:前向输出用稀疏结果,但反向传播时梯度来源于 dense 估计
final_flat = sparse_output.detach() + (dense_outputs - dense_outputs.detach())
final_output = final_flat.view(batch_size, seq_length, hidden_dim)
return final_output, router_logits
def estimate_dense_output(self, token_idx, activated, gate_prob, activated_outputs, all_routing, all_expert_outputs):
"""
对于当前 token,根据 mini-batch 中的信息估计 dense 输出。
参数:
token_idx: 当前 token 的索引(标量)
activated: 当前 token 激活的专家列表,例如 [1, 3]
gate_prob: 当前 token 的 routing 权重,形状 (num_experts,)
activated_outputs: dict,当前 token 对激活专家的实际输出,形状 (hidden_dim,)
all_routing: list,每个 token 的激活专家列表(长度为 N,每个元素为 list)
all_expert_outputs: Tensor, (N, num_experts, hidden_dim)
返回:
estimated_dense: Tensor, (hidden_dim,)
"""
num_experts = gate_prob.size(0)
dense_parts = {}
# 对于激活的专家,直接使用其实际输出
for idx in activated:
dense_parts[idx] = activated_outputs[idx]
# 对于未激活的专家,使用 mini-batch 中其他 token 的输出估计
non_activated = [i for i in range(num_experts) if i not in activated]
for i in non_activated:
indices = []
for idx, r_dec in enumerate(all_routing):
if (i in r_dec) and (len(set(r_dec) & set(activated)) > 0):
indices.append(idx)
if indices:
selected_outputs = all_expert_outputs[indices, i, :] # (n, hidden_dim)
estimated = selected_outputs.mean(dim=0)
else:
estimated = all_expert_outputs[:, i, :].mean(dim=0)
dense_parts[i] = estimated
# 按 gate_prob 加权求和各专家输出
estimated_dense = 0
for i in range(num_experts):
estimated_dense += gate_prob[i] * dense_parts[i]
return estimated_dense
class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
"""
自定义的 Olmoe ForCausalLM 模型,使用新的 DenseBackwardOlmoeSparseMoeBlock 替换原版的 MoE 模块,
以实现 dense backward 功能。
配置类:DenseBackwardOLMoEConfig
"""
config_class = DenseBackwardOLMoEConfig
base_model_prefix = "olmoe"
def __init__(self, config):
# 首先调用父类初始化方法
super().__init__(config)
# 不要尝试重新赋值self,而是从预训练模型加载并更新当前模型
pretrained_model = OlmoeForCausalLM.from_pretrained("allenai/OLMoE-1B-7B-0924", torch_dtype=torch.bfloat16)
# 复制预训练模型的状态到当前模型
self.config = pretrained_model.config
self.model = pretrained_model.model
self.vocab_size = pretrained_model.vocab_size
self.router_aux_loss_coef = pretrained_model.router_aux_loss_coef
self.num_experts = pretrained_model.num_experts
self.lm_head = pretrained_model.lm_head
# 遍历模型中所有 decoder 层,替换每个 OlmoeSparseMoeBlock 为 DenseBackward 版本
# 此处假设官方模型在 self.model.layers 中组织 decoder 层,
# 且每层中 mlp 模块包含属性 sparse_moe_block。
for layer in self.model.layers:
if hasattr(layer.mlp, "gate"):
print("111")
orig_block = layer.mlp
# 通过直接复制原版属性创建新的块
new_block = DenseBackwardOlmoeSparseMoeBlock(config) # 或其他适当参数
# 然后手动复制需要共享的属性:
new_block.gate = orig_block.gate
new_block.experts = orig_block.experts
new_block.num_experts = orig_block.num_experts
new_block.top_k = orig_block.top_k
new_block.norm_topk_prob = orig_block.norm_topk_prob
layer.mlp = new_block
print(type(layer.mlp))
# 在调用post_init()前
test_param = self.model.layers[0].mlp.experts[0].up_proj.weight.data[0, 0].item()
print(f"权重示例值(前): {test_param}")
self.post_init()
# 在调用post_init()后
test_param_after = self.model.layers[0].mlp.experts[0].up_proj.weight.data[0, 0].item()
print(f"权重示例值(后): {test_param_after}")
def main():
config = DenseBackwardOLMoEConfig( # 官方模型参数
model_marker="DenseBackward_olmoe_marker",
torch_dtype="bfloat16"
)
# 创建自定义模型实例
model = DenseBackwardOLMoEForCausalLM(config)
print(type(model))
print(type(model.model))
print(type(model.model.layers[0]))
print(type(model.model.layers[0].mlp))
print(type(model.model.layers[0].mlp.experts))
if __name__ == "__main__":
main() |