from transformers import PretrainedConfig
class EmuruConfig(PretrainedConfig):
    """Configuration for the Emuru model.

    Stores the identifiers of the sub-components Emuru is built from, plus a
    few integer hyper-parameters. Any extra keyword arguments are forwarded
    unchanged to ``PretrainedConfig.__init__``.

    Args:
        t5_config: Identifier for the T5 component. Defaults to
            ``'google-t5/t5-large'``. NOTE(review): presumably a Hugging Face
            Hub repo id or local path consumed by the model code; confirm.
        vae_config: Identifier for the VAE component. Defaults to
            ``'blowing-up-groundhogs/emuru_vae'`` (same caveat as above).
        tokenizer_config: Identifier for the tokenizer. Defaults to
            ``'google/byt5-small'`` (same caveat as above).
        slices_per_query: Defaults to ``1``. NOTE(review): its meaning is
            defined by the model code, which is not visible here.
        vae_channels: Defaults to ``1``. NOTE(review): presumably the number
            of VAE channels; confirm against the model code.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    # Model-type identifier Transformers associates with this config class.
    model_type = "emuru"

    def __init__(
        self,
        t5_config: str = 'google-t5/t5-large',
        vae_config: str = 'blowing-up-groundhogs/emuru_vae',
        tokenizer_config: str = 'google/byt5-small',
        slices_per_query: int = 1,
        vae_channels: int = 1,
        **kwargs,
    ) -> None:
        # Let the base class consume its own generic options first.
        super().__init__(**kwargs)
        # Emuru-specific settings are stored as plain attributes.
        self.t5_config = t5_config
        self.vae_config = vae_config
        self.tokenizer_config = tokenizer_config
        self.slices_per_query = slices_per_query
        self.vae_channels = vae_channels