# Hosting-page metadata (not part of the module): file size 919 Bytes, revision 52bb403
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# my_custom_olmoe/configuration_custom.py

# NOTE: depending on your transformers version, the import path of the official OLMoE config may need to be adjusted.
from transformers.models.olmoe.configuration_olmoe import OlmoeConfig

class DenseBackwardOLMoEConfig(OlmoeConfig):
    """OLMoE configuration variant used by the DenseBackward model.

    Extends ``OlmoeConfig`` with a ``model_marker`` attribute and uses 1024
    as the default ``intermediate_size``.

    Args:
        model_marker: Free-form marker stored on the config as
            ``self.model_marker``.
        intermediate_size: Value forwarded to ``OlmoeConfig`` as
            ``intermediate_size``. Defaults to 1024. An explicitly supplied
            value (for example one read back from a saved config) is kept
            rather than being overwritten.
        **kwargs: Forwarded unchanged to ``OlmoeConfig.__init__``.
    """

    # Overrides the inherited model_type so this variant can be identified later.
    model_type = "DenseBackward_olmoe"

    # Maps Auto* class names to the custom classes, for AutoClass support.
    # NOTE(review): this is a class attribute; confirm it ends up in the saved
    # config.json for the transformers version in use.
    auto_map = {
        "AutoConfig": "configuration_custom.DenseBackwardOLMoEConfig",
        "AutoModelForCausalLM": "modeling_custom.DenseBackwardOLMoEForCausalLM",
    }

    def __init__(
        self,
        model_marker: str = "DenseBackward_olmoe_marker",
        intermediate_size: int = 1024,
        **kwargs,
    ):
        # Forward intermediate_size to the parent instead of hard-coding it
        # afterwards, so a caller-supplied value is no longer silently reset.
        super().__init__(intermediate_size=intermediate_size, **kwargs)
        self.model_marker = model_marker
# Smoke test: construct the config and print it.
def main():
    """Build a config with the default marker and print it to stdout."""
    print(DenseBackwardOLMoEConfig(model_marker="DenseBackward_olmoe_marker"))

# Run main() only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()