from transformers.configuration_utils import PretrainedConfig
import sys
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    LlamaConfig,
    LlamaForCausalLM,
    PreTrainedModel,
)

from .attrdict_config import AttrDict


class _ClsParamsConfig(PretrainedConfig):
    """Shared base for component configs that carry a ``cls`` name and ``params``.

    The five component configs below differ only in ``model_type``; their
    identical ``__init__`` lives here once instead of being copy-pasted.

    Keyword Args:
        cls: A name string, or an object with a ``__name__`` attribute (e.g. a
            class), which is reduced to that ``__name__``. Defaults to ``""``.
        params: A mapping that is wrapped in ``AttrDict``. Defaults to ``{}``.
            NOTE(review): presumably the keyword arguments used to build
            ``cls`` elsewhere in the model code -- confirm against callers.
        **kwargs: All keywords (including ``cls`` and ``params``) are also
            forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    cls: str = ""
    params: AttrDict = {}

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        self.cls = kwargs.get("cls", "")
        # Keep only a plain string name when a non-string (e.g. a class
        # object) was passed in.
        if not isinstance(self.cls, str):
            self.cls = self.cls.__name__

        self.params = AttrDict(kwargs.get("params", {}))


class VisionConfig(_ClsParamsConfig):
    """``cls``/``params`` config registered under ``model_type = "vision"``."""

    model_type = "vision"


class AlignerConfig(_ClsParamsConfig):
    """``cls``/``params`` config registered under ``model_type = "aligner"``."""

    model_type = "aligner"


class GenVisionConfig(_ClsParamsConfig):
    """``cls``/``params`` config registered under ``model_type = "gen_vision"``."""

    model_type = "gen_vision"


class GenAlignerConfig(_ClsParamsConfig):
    """``cls``/``params`` config registered under ``model_type = "gen_aligner"``."""

    model_type = "gen_aligner"


class GenHeadConfig(_ClsParamsConfig):
    """``cls``/``params`` config registered under ``model_type = "gen_head"``."""

    model_type = "gen_head"


def _coerce_sub_config(value, config_cls):
    """Return ``value`` as an instance of ``config_cls``.

    Args:
        value: ``None`` (use defaults), an existing ``config_cls`` instance
            (returned as-is), or a mapping of keyword arguments for
            ``config_cls``.
        config_cls: The ``PretrainedConfig`` subclass to produce.

    Returns:
        An instance of ``config_cls``.

    Raises:
        TypeError: If ``value`` is neither ``None``, a ``config_cls``
            instance, nor a mapping.
    """
    if value is None:
        return config_cls()
    if isinstance(value, config_cls):
        return value
    return config_cls(**value)


class MultiModalityConfig(PretrainedConfig):
    """Top-level config holding one sub-config per model component.

    Each of the keywords ``vision_config``, ``aligner_config``,
    ``gen_vision_config``, ``gen_aligner_config``, ``gen_head_config`` and
    ``language_config`` may be omitted or ``None`` (giving a default
    sub-config), a mapping of keyword arguments, or an already-built instance
    of the matching config class. ``language_config`` becomes a
    ``LlamaConfig``. All keywords are also forwarded to
    ``PretrainedConfig.__init__``.
    """

    model_type = "multi_modality"
    vision_config: VisionConfig
    aligner_config: AlignerConfig
    gen_vision_config: GenVisionConfig
    gen_aligner_config: GenAlignerConfig
    gen_head_config: GenHeadConfig
    language_config: LlamaConfig

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

        # super().__init__ also received the raw sub-config values; each one
        # is (re)assigned below as a proper config object. Previously only
        # language_config accepted a pre-built instance, and an explicit None
        # crashed on `**None` -- the helper handles both uniformly.
        self.vision_config = _coerce_sub_config(
            kwargs.get("vision_config"), VisionConfig
        )
        self.aligner_config = _coerce_sub_config(
            kwargs.get("aligner_config"), AlignerConfig
        )
        self.gen_vision_config = _coerce_sub_config(
            kwargs.get("gen_vision_config"), GenVisionConfig
        )
        self.gen_aligner_config = _coerce_sub_config(
            kwargs.get("gen_aligner_config"), GenAlignerConfig
        )
        self.gen_head_config = _coerce_sub_config(
            kwargs.get("gen_head_config"), GenHeadConfig
        )
        self.language_config = _coerce_sub_config(
            kwargs.get("language_config"), LlamaConfig
        )