from transformers import PretrainedConfig


class EmuruConfig(PretrainedConfig):
    """Configuration for the Emuru model, bundling the checkpoints of its sub-components."""

    model_type = "emuru"

    def __init__(self,
                 t5_name_or_path='google-t5/t5-large',
                 vae_name_or_path='blowing-up-groundhogs/emuru_vae',
                 tokenizer_name_or_path='google/byt5-small',
                 slices_per_query=1,
                 vae_channels=1,
                 **kwargs):
        super().__init__(**kwargs)
        self.t5_name_or_path = t5_name_or_path                  # T5 backbone checkpoint
        self.vae_name_or_path = vae_name_or_path                # pretrained VAE checkpoint
        self.tokenizer_name_or_path = tokenizer_name_or_path    # tokenizer checkpoint
        self.slices_per_query = slices_per_query                # latent slices consumed per query
        self.vae_channels = vae_channels                        # number of VAE latent channels
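

# A minimal usage sketch (an illustration, not part of the published API):
# the round trip below relies only on PretrainedConfig's standard JSON
# serialization, and the "./emuru-checkpoint" path is a placeholder.
if __name__ == "__main__":
    # Build a config, overriding one field to show that the defaults are optional.
    config = EmuruConfig(slices_per_query=2)

    # save_pretrained / from_pretrained are inherited from PretrainedConfig.
    config.save_pretrained("./emuru-checkpoint")
    reloaded = EmuruConfig.from_pretrained("./emuru-checkpoint")
    assert reloaded.slices_per_query == 2
    assert reloaded.model_type == "emuru"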