# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""Retrieves the pretrained models for Moshi and Mimi."""
from pathlib import Path

from safetensors.torch import load_model
import torch

from moshi.models.compression import MimiModel
from moshi.models.lm import LMModel
from moshi.modules import SEANetEncoder, SEANetDecoder, transformer
from moshi.quantization import SplitResidualVectorQuantizer

SAMPLE_RATE = 24000
FRAME_RATE = 12.5

TEXT_TOKENIZER_NAME = 'tokenizer_spm_32k_3.model'
MOSHI_NAME = 'model.safetensors'
MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'

# Configuration of the SEANet convolutional encoder/decoder used by Mimi.
_seanet_kwargs = {
    "channels": 1,
    "dimension": 512,
    "causal": True,
    "n_filters": 64,
    "n_residual_layers": 1,
    "activation": "ELU",
    "compress": 2,
    "dilation_base": 2,
    "disable_norm_outer_blocks": 0,
    "kernel_size": 7,
    "residual_kernel_size": 3,
    "last_kernel_size": 3,
    # We train using weight_norm but then the weights are pre-processed for inference so
    # that we can use a normal convolution.
    "norm": "none",
    "pad_mode": "constant",
    "ratios": [8, 6, 5, 4],
    "true_skip": True,
}

# Configuration of the split residual vector quantizer, in which the first
# codebook is handled separately from the remaining ones.
_quantizer_kwargs = {
    "dimension": 256,
    "n_q": 32,
    "bins": 2048,
    "input_dimension": _seanet_kwargs["dimension"],
    "output_dimension": _seanet_kwargs["dimension"],
}

# Configuration of the small transformers applied on top of the SEANet
# encoder and decoder outputs.
_transformer_kwargs = {
    "d_model": _seanet_kwargs["dimension"],
    "num_heads": 8,
    "num_layers": 8,
    "causal": True,
    "layer_scale": 0.01,
    "context": 250,
    "conv_layout": True,
    "max_period": 10000,
    "gating": "none",
    "norm": "layer_norm",
    "positional_embedding": "rope",
    "dim_feedforward": 2048,
    "input_dimension": _seanet_kwargs["dimension"],
    "output_dimensions": [_seanet_kwargs["dimension"]],
}

# Configuration of the Moshi language model: a large temporal transformer,
# plus a smaller depth transformer ("depformer") over the codebooks of each frame.
_lm_kwargs = {
    "dim": 4096,
    "text_card": 32000,
    "existing_text_padding_id": 3,
    "n_q": 16,
    "dep_q": 8,
    "card": _quantizer_kwargs["bins"],
    "num_heads": 32,
    "num_layers": 32,
    "hidden_scale": 4.125,
    "causal": True,
    "layer_scale": None,
    "context": 3000,
    "max_period": 10000,
    "gating": "silu",
    "norm": "rms_norm_f32",
    "positional_embedding": "rope",
    "depformer_dim": 1024,
    "depformer_dim_feedforward": int(4.125 * 1024),
    "depformer_num_heads": 16,
    "depformer_num_layers": 6,
    "depformer_causal": True,
    "depformer_layer_scale": None,
    "depformer_multi_linear": True,
    "depformer_context": 8,
    "depformer_max_period": 10000,
    "depformer_gating": "silu",
    "depformer_pos_emb": "none",
    "depformer_weights_per_step": True,
    "delays": [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1],
}


def _is_safetensors(path: Path | str) -> bool:
    return Path(path).suffix in (".safetensors", ".sft", ".sfts")


def get_mimi(filename: str | Path,
             device: torch.device | str = 'cpu') -> MimiModel:
    """Return a pretrained Mimi model."""
    encoder = SEANetEncoder(**_seanet_kwargs)
    decoder = SEANetDecoder(**_seanet_kwargs)
    encoder_transformer = transformer.ProjectedTransformer(
        device=device, **_transformer_kwargs
    )
    decoder_transformer = transformer.ProjectedTransformer(
        device=device, **_transformer_kwargs
    )
    quantizer = SplitResidualVectorQuantizer(
        **_quantizer_kwargs,
    )
    model = MimiModel(
        encoder,
        decoder,
        quantizer,
        channels=1,
        sample_rate=SAMPLE_RATE,
        frame_rate=FRAME_RATE,
        encoder_frame_rate=SAMPLE_RATE / encoder.hop_length,
        causal=True,
        resample_method="conv",
        encoder_transformer=encoder_transformer,
        decoder_transformer=decoder_transformer,
    ).to(device=device)
    model.eval()
    if _is_safetensors(filename):
        load_model(model, filename)
    else:
        pkg = torch.load(filename, "cpu")
        model.load_state_dict(pkg["model"])
    # Restrict to the first 8 codebooks out of the 32 supported by the quantizer.
    model.set_num_codebooks(8)
    return model


def get_moshi_lm(filename: str | Path,
                 device: torch.device | str = 'cpu') -> LMModel:
    """Return a pretrained Moshi language model."""
    dtype = torch.bfloat16
    model = LMModel(
        device=device,
        dtype=dtype,
        **_lm_kwargs,
    ).to(device=device, dtype=dtype)
    model.eval()
    if _is_safetensors(filename):
        load_model(model, filename)
    else:
        # Raw PyTorch training checkpoints store the weights under the FSDP best state.
        pkg = torch.load(
            filename,
            "cpu",
        )
        model.load_state_dict(pkg["fsdp_best_state"]["model"])
    return model
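

# Usage sketch (illustrative, not part of the upstream API surface): fetch the
# default checkpoints named above from `DEFAULT_REPO` and load both models.
# Assumes the `huggingface_hub` package is installed; the random waveform is
# only a stand-in for real 24 kHz mono audio.
if __name__ == "__main__":
    from huggingface_hub import hf_hub_download

    mimi_weight = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
    mimi = get_mimi(mimi_weight, device="cpu")
    with torch.no_grad():
        # Encode one second of audio; `codes` has shape (batch, codebooks, frames),
        # with roughly FRAME_RATE frames per second of input.
        codes = mimi.encode(torch.randn(1, 1, SAMPLE_RATE))
    print("Mimi codes shape:", tuple(codes.shape))

    moshi_weight = hf_hub_download(DEFAULT_REPO, MOSHI_NAME)
    moshi = get_moshi_lm(moshi_weight, device="cpu")
    print("Moshi LM parameters:", sum(p.numel() for p in moshi.parameters()))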