# my_custom_olmoe/modeling_custom.py
import torch
import torch.nn as nn
import torch.nn.functional as F

# Official OLMoE implementation (adjust the import path to match the
# installed transformers version).
from transformers.models.olmoe.modeling_olmoe import OlmoeForCausalLM, OlmoeSparseMoeBlock, OlmoeMLP

from .configuration_densebackward_olmoe import DenseBackwardOLMoEConfig


class DenseBackwardOlmoeSparseMoeBlock(OlmoeSparseMoeBlock):
    """Sparse MoE block whose backward pass uses a dense output estimate.

    The forward value is the usual top-k sparse mixture.  Gradients, however,
    flow through an estimate of the dense mixture (every expert weighted by
    its full softmax routing probability) that is assembled only from expert
    outputs already computed for the sparse pass.
    """

    def forward(self, hidden_states: torch.Tensor):
        """Run the sparse mixture and attach the dense-estimate gradient path.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden_dim).

        Returns:
            A tuple ``(final_output, router_logits)`` where ``final_output``
            has the shape and dtype of ``hidden_states`` and ``router_logits``
            has shape (batch * seq_len, num_experts).
        """
        batch_size, seq_length, hidden_dim = hidden_states.shape
        flat_hidden = hidden_states.view(-1, hidden_dim)  # (B*seq_len, hidden_dim)
        num_tokens = flat_hidden.size(0)

        # Router logits and full softmax routing weights (kept in float32).
        router_logits = self.gate(flat_hidden)  # (B*seq_len, num_experts)
        if num_tokens == 0:
            # Nothing to route; the per-token estimation below cannot
            # stack/concatenate an empty list of rows.
            return hidden_states.new_zeros(hidden_states.shape), router_logits
        routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)

        # Top-k selection.
        routing_weights_topk, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)
        if self.norm_topk_prob:
            routing_weights_topk = routing_weights_topk / routing_weights_topk.sum(dim=-1, keepdim=True)
        routing_weights_topk = routing_weights_topk.to(flat_hidden.dtype)

        # ---------- Sparse computation ----------
        sparse_output = torch.zeros(
            (num_tokens, hidden_dim),
            dtype=flat_hidden.dtype,
            device=flat_hidden.device,
        )

        # Bookkeeping reused by the dense estimate:
        all_activated_outputs = {}    # {expert_idx: {token_idx: output_tensor}}
        all_routing_indices = {}      # {expert_idx: [token_indices]}
        token_activated_experts = {}  # {token_idx: [activated_expert_indices]}

        # One-hot encode the selected experts: (B*seq_len, top_k, num_experts),
        # then permute to (num_experts, top_k, B*seq_len).
        expert_mask = F.one_hot(selected_experts, num_classes=self.num_experts)
        expert_mask = expert_mask.permute(2, 1, 0)

        for expert_idx in range(self.num_experts):
            idx, top_x = torch.where(expert_mask[expert_idx])
            if top_x.numel() == 0:
                continue

            expert_layer = self.experts[expert_idx]
            current_output = expert_layer(flat_hidden[top_x])  # (n, hidden_dim)
            weight = routing_weights_topk[top_x, idx].unsqueeze(-1)  # (n, 1)
            weighted_output = current_output * weight
            sparse_output.index_add_(0, top_x, weighted_output.to(flat_hidden.dtype))

            # Record which tokens this expert processed and the (unweighted)
            # output it produced for each of them.
            token_ids = top_x.tolist()
            all_routing_indices[expert_idx] = token_ids
            all_activated_outputs[expert_idx] = dict(zip(token_ids, current_output.unbind(0)))
            for token_idx in token_ids:
                token_activated_experts.setdefault(token_idx, []).append(expert_idx)
        # ---------- End of sparse computation ----------

        # ---------- Dense estimate ----------
        # Per-token list of selected experts, aligned with routing_weights.
        all_routing = selected_experts.tolist()

        dense_rows = [
            self.estimate_dense_output_efficient(
                token_idx=token_idx,
                activated=all_routing[token_idx],
                gate_prob=routing_weights[token_idx],
                all_activated_outputs=all_activated_outputs,
                all_routing_indices=all_routing_indices,
                token_activated_experts=token_activated_experts,
            )
            for token_idx in range(num_tokens)
        ]
        dense_outputs = torch.stack(dense_rows, dim=0)  # (B*seq_len, hidden_dim)
        # ---------- End of dense estimate ----------

        # Straight-through trick: the forward value is the sparse output, the
        # gradient is that of the dense estimate.  The dense estimate is
        # float32 (routing_weights is float32), so the zero-valued gradient
        # carrier is cast back to the hidden dtype; otherwise type promotion
        # would turn a bf16/fp16 block output into float32.
        straight_through = (dense_outputs - dense_outputs.detach()).to(sparse_output.dtype)
        final_flat = sparse_output.detach() + straight_through
        final_output = final_flat.view(batch_size, seq_length, hidden_dim)
        return final_output, router_logits

    def estimate_dense_output_efficient(self, token_idx, activated, gate_prob,
                                        all_activated_outputs, all_routing_indices,
                                        token_activated_experts):
        """Estimate one token's dense mixture from already-computed outputs.

        For experts the token activated, its own expert output is used.  For
        every other expert, the output is approximated by the mean of that
        expert's outputs over tokens that were routed to it and share at
        least one activated expert with the current token; if there is no
        such token, the mean over all tokens routed to that expert is used.
        Experts that processed no token at all contribute zero.

        Args:
            token_idx: Index of the token in the flattened batch.
            activated: Expert indices selected for this token.
            gate_prob: Routing probabilities of this token, one per expert.
            all_activated_outputs: ``{expert_idx: {token_idx: output}}``.
            all_routing_indices: ``{expert_idx: [token_idx, ...]}``.
            token_activated_experts: ``{token_idx: [expert_idx, ...]}``.

        Returns:
            The sum over experts of ``gate_prob[e] * output_e`` using the
            real or estimated outputs described above.
        """
        num_experts = gate_prob.size(0)
        activated_set = set(activated)
        dense_parts = {}

        # Activated experts: use the output actually computed for this token.
        for expert_idx in activated:
            own_outputs = all_activated_outputs.get(expert_idx, {})
            if token_idx in own_outputs:
                dense_parts[expert_idx] = own_outputs[token_idx]

        # Non-activated experts: borrow outputs computed for other tokens.
        for expert_idx in range(num_experts):
            if expert_idx in activated_set:
                continue

            routed_tokens = all_routing_indices.get(expert_idx)
            if not routed_tokens:
                # No token used this expert in the batch: contribute zero.
                dense_parts[expert_idx] = (
                    torch.zeros_like(next(iter(dense_parts.values()))) if dense_parts else 0
                )
                continue

            # Prefer tokens that share at least one activated expert with the
            # current token; otherwise fall back to every token routed here.
            candidate_tokens = [
                other_token
                for other_token in routed_tokens
                if not activated_set.isdisjoint(token_activated_experts.get(other_token, ()))
            ]
            if not candidate_tokens:
                candidate_tokens = routed_tokens

            expert_outputs = all_activated_outputs[expert_idx]
            dense_parts[expert_idx] = torch.stack(
                [expert_outputs[t] for t in candidate_tokens]
            ).mean(dim=0)

        # Weighted sum by routing probability.  Out-of-place addition keeps
        # this safe for autograd and for scalar placeholder terms.
        estimated_dense = 0
        for expert_idx in range(num_experts):
            if expert_idx in dense_parts:
                estimated_dense = estimated_dense + gate_prob[expert_idx] * dense_parts[expert_idx]
        return estimated_dense


class DenseBackwardOLMoEForCausalLM(OlmoeForCausalLM):
    """OLMoE causal LM whose MoE blocks use ``DenseBackwardOlmoeSparseMoeBlock``.

    On construction the pretrained OLMoE checkpoint is loaded and every
    decoder layer's sparse MoE block is switched to the dense-backward
    variant while keeping the pretrained gate and expert weights.

    Config class: DenseBackwardOLMoEConfig
    """

    config_class = DenseBackwardOLMoEConfig
    # NOTE(review): the base model is stored on ``self.model`` below, while
    # this prefix says "olmoe" -- confirm this is intended for checkpoint
    # loading/saving.
    base_model_prefix = "olmoe"
    # Checkpoint the weights are taken from; subclasses may override it.
    pretrained_checkpoint = "allenai/OLMoE-1B-7B-0924"

    def __init__(self, config):
        """Build the model and adopt the pretrained OLMoE weights.

        Args:
            config: Model configuration passed to the parent constructor.
                Note that ``self.config`` is subsequently replaced by the
                pretrained checkpoint's configuration.
        """
        # Initialise the parent first.
        super().__init__(config)

        # ``self`` cannot be reassigned, so load the pretrained model and
        # move its state onto this instance instead.
        pretrained_model = OlmoeForCausalLM.from_pretrained(
            self.pretrained_checkpoint, torch_dtype=torch.bfloat16
        )
        # NOTE(review): this discards the ``config`` argument (and any custom
        # fields on it) in favour of the checkpoint's config -- confirm.
        self.config = pretrained_model.config
        self.model = pretrained_model.model
        self.vocab_size = pretrained_model.vocab_size
        self.router_aux_loss_coef = pretrained_model.router_aux_loss_coef
        self.num_experts = pretrained_model.num_experts
        self.lm_head = pretrained_model.lm_head

        # Walk the decoder layers (``self.model.layers``) and switch each
        # sparse MoE block, recognised by its ``gate`` attribute, to the
        # dense-backward variant.
        for layer in self.model.layers:
            if hasattr(layer.mlp, "gate"):
                print("111")
                # The subclass only overrides methods and adds no state, so
                # re-typing the existing block keeps its gate, experts,
                # num_experts, top_k and norm_topk_prob as they are and avoids
                # allocating a throw-away set of expert weights per layer.
                layer.mlp.__class__ = DenseBackwardOlmoeSparseMoeBlock
                print(type(layer.mlp))

        # Sample one weight before post_init() ...
        test_param = self.model.layers[0].mlp.experts[0].up_proj.weight.data[0, 0].item()
        print(f"权重示例值(前): {test_param}")
        self.post_init()
        # ... and after, to check that post_init() did not re-initialise it.
        test_param_after = self.model.layers[0].mlp.experts[0].up_proj.weight.data[0, 0].item()
        print(f"权重示例值(后): {test_param_after}")


def main():
    """Instantiate the custom model and print the types of its key modules."""
    config = DenseBackwardOLMoEConfig(
        # Official model parameters.
        model_marker="DenseBackward_olmoe_marker",
        torch_dtype="bfloat16",
    )
    # Create the custom model instance.
    model = DenseBackwardOLMoEForCausalLM(config)
    print(type(model))
    print(type(model.model))
    print(type(model.model.layers[0]))
    print(type(model.model.layers[0].mlp))
    print(type(model.model.layers[0].mlp.experts))


if __name__ == "__main__":
    main()