from typing import List, Optional

from transformers import PretrainedConfig


class MoonshineConfig(PretrainedConfig):
    """Configuration container for a Moonshine model.

    Stores the architecture hyper-parameters as plain attributes and forwards
    any remaining keyword arguments to ``PretrainedConfig``.

    Args:
        dim: Base model width. Also used as ``inner_dim`` when that is not given.
        inner_dim: Inner width; must be divisible by ``n_head`` (presumably
            because it is split evenly across attention heads — confirm against
            the model code). Defaults to ``dim`` when ``None``.
        enc_depth: Presumably the number of encoder layers.
        dec_depth: Presumably the number of decoder layers.
        n_head: Number of heads; must be a positive divisor of ``inner_dim``.
        dec_voc_size: Presumably the decoder vocabulary size.
        enc_ff_swiglu: Flag presumably selecting a SwiGLU feed-forward block in
            the encoder.
        dec_ff_swiglu: Same as ``enc_ff_swiglu``, but for the decoder.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.

    Raises:
        ValueError: If ``n_head`` is not positive, or if ``inner_dim`` is not
            divisible by ``n_head``.
    """

    # Identifier transformers uses for this configuration type.
    model_type = "moonshine"

    def __init__(
        self,
        dim: int = 288,
        inner_dim: Optional[int] = None,
        enc_depth: int = 8,
        dec_depth: int = 8,
        n_head: int = 8,
        dec_voc_size: int = 32768,
        enc_ff_swiglu: bool = False,
        dec_ff_swiglu: bool = True,
        **kwargs,
    ):
        if inner_dim is None:
            inner_dim = dim
        # Reject non-positive head counts explicitly. Zero would otherwise
        # surface as a bare ZeroDivisionError from the modulo below. A negative
        # value can slip through the divisibility check (e.g. 288 % -8 == 0).
        if n_head <= 0:
            raise ValueError(f"`n_head` must be a positive integer, got {n_head}")
        if inner_dim % n_head != 0:
            raise ValueError(
                "`inner_dim` must be divisible by `n_head` "
                f"(got inner_dim={inner_dim}, n_head={n_head})"
            )
        self.dim = dim
        self.inner_dim = inner_dim
        self.enc_depth = enc_depth
        self.dec_depth = dec_depth
        self.n_head = n_head
        self.dec_voc_size = dec_voc_size
        self.enc_ff_swiglu = enc_ff_swiglu
        self.dec_ff_swiglu = dec_ff_swiglu
        super().__init__(**kwargs)