# Source snapshot: commit 52bb403, 919 bytes.
# my_custom_olmoe/configuration_custom.py
# 注意:根据你的 transformers 版本,导入官方 OLMoE 配置的路径可能需要调整
from transformers.models.olmoe.configuration_olmoe import OlmoeConfig
class DenseBackwardOLMoEConfig(OlmoeConfig):
    """Configuration for the DenseBackward OLMoE variant.

    Subclasses ``OlmoeConfig`` and adds:

    * a distinct ``model_type`` so this config can be told apart from the
      stock OLMoE one;
    * an ``auto_map`` pointing ``AutoConfig`` / ``AutoModelForCausalLM`` at
      the custom ``configuration_custom`` / ``modeling_custom`` modules;
    * a ``model_marker`` string attribute;
    * a default ``intermediate_size`` of 1024.

    Args:
        model_marker: Tag stored as ``self.model_marker``.
        intermediate_size: Forwarded to ``OlmoeConfig``. Defaults to 1024.
            An explicitly supplied value is honoured, including one loaded
            from a saved ``config.json``.
        **kwargs: Forwarded unchanged to ``OlmoeConfig.__init__``.
    """

    # Overrides OlmoeConfig.model_type so this variant can be identified later.
    model_type = "DenseBackward_olmoe"

    # Maps Auto* factories to the custom classes ("module.ClassName") so that
    # AutoConfig / AutoModelForCausalLM can locate them.
    auto_map = {
        "AutoConfig": "configuration_custom.DenseBackwardOLMoEConfig",
        "AutoModelForCausalLM": "modeling_custom.DenseBackwardOLMoEForCausalLM",
    }

    def __init__(
        self,
        model_marker="DenseBackward_olmoe_marker",
        *,
        intermediate_size=1024,
        **kwargs,
    ):
        # A class attribute is not part of the instance __dict__, which is what
        # PretrainedConfig serialises. Put a copy into kwargs so the parent
        # stores it on the instance and it reaches the saved config.json.
        # setdefault keeps any auto_map the caller or a loaded config supplies.
        kwargs.setdefault("auto_map", dict(type(self).auto_map))
        # Pass intermediate_size through to the parent rather than assigning it
        # afterwards, so a caller- or checkpoint-supplied value is not clobbered.
        super().__init__(intermediate_size=intermediate_size, **kwargs)
        self.model_marker = model_marker
# Manual smoke test: run this file directly to build and print a default config.
def main():
    """Build a config carrying the default marker and print it to stdout."""
    print(DenseBackwardOLMoEConfig(model_marker="DenseBackward_olmoe_marker"))
# Script entry point: only run the smoke test when executed directly.
if __name__ == "__main__":
    main()