tezuesh committed on
Commit 22d5f88 · verified · 1 Parent(s): 828c20a

Upload folder using huggingface_hub

Dockerfile ADDED
@@ -0,0 +1,30 @@
1
+ FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime
2
+
3
+ WORKDIR /app
4
+
5
+ # Install system dependencies
6
+ RUN apt-get update && \
7
+ DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
8
+ build-essential \
9
+ python3-dev \
10
+ git \
11
+ ffmpeg \
12
+ libsndfile1 \
13
+ && rm -rf /var/lib/apt/lists/*
14
+
15
+ # Pre-install critical dependencies
16
+ RUN pip install --no-cache-dir --upgrade pip setuptools wheel
17
+
18
+ # Copy entire directory
19
+ COPY . .
20
+
21
+ # Install package in editable mode
22
+ RUN pip install -e .
23
+
24
+ # Environment variables
25
+ ENV MODEL_PATH=/app/models \
26
+ PYTHONUNBUFFERED=1
27
+
28
+ EXPOSE 8000
29
+
30
+ CMD ["python", "server.py"]
README.md ADDED
@@ -0,0 +1,32 @@
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - any-to-any
5
+ - omega
6
+ - omegalabs
7
+ - bittensor
8
+ - agi
9
+ ---
10
+
11
+ This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.
12
+
13
+ Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).
14
+
15
+ ## License
16
+
17
+ This project is licensed under the MIT License - see the LICENSE file for details.
18
+
19
+ ## Links
20
+
21
+ - GitHub Repository: [omegalabs-anytoany-bittensor](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor)
22
+ - OMEGA Labs on X: [@omegalabsai](https://x.com/omegalabsai)
23
+
24
+ ## Contributing
25
+
26
+ Contributions are welcome! Please feel free to submit a Pull Request.
27
+
28
+ ## Support
29
+
30
+ For support and questions, please:
31
+ 1. Open an issue on GitHub
32
+ 2. Follow OMEGA Labs on X [@omegalabsai](https://x.com/omegalabsai)
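33
+
34
+ ## Usage
35
+
36
+ A minimal local-inference sketch based on `inference.py` in this repository (the model directory and wav path below are placeholders):
37
+
38
+ ```python
39
+ import librosa
40
+ from inference import InferenceRecipe
41
+
42
+ # /app/models matches MODEL_PATH in the Dockerfile; it must contain the checkpoints from models/.
43
+ model = InferenceRecipe("/app/models", device="cuda")
44
+ audio, sr = librosa.load("input.wav", sr=None)
45
+ result = model.inference(audio, sr)
46
+ print(result["text"])
47
+ ```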
chunk_silence.py ADDED
@@ -0,0 +1,53 @@
1
+ import os
2
+ from pydub import AudioSegment
3
+ from pydub.silence import detect_silence
4
+ import glob
5
+ from tqdm import tqdm
6
+ import numpy as np
+ import torch
7
+
8
+ def detect_and_trim_silence(audio_array, frame_rate, min_silence_duration=1000, silence_threshold=-40):
9
+ # Convert the input tensor to 16-bit PCM bytes
10
+ audio_array = audio_array.detach().cpu().numpy()
11
+ audio_bytes = (audio_array * 32767).astype(np.int16).tobytes()
12
+
13
+ # Create AudioSegment from raw audio bytes
14
+ audio = AudioSegment(
15
+ data=audio_bytes,
16
+ sample_width=2, # 16-bit audio = 2 bytes
17
+ frame_rate=frame_rate,
18
+ channels=1 # Mono audio
19
+ )
20
+
21
+ # Detect silence
22
+ silence_intervals = detect_silence(
23
+ audio,
24
+ min_silence_len=min_silence_duration,
25
+ silence_thresh=silence_threshold
26
+ )
27
+
28
+ # Convert milliseconds to seconds
29
+ silence_intervals_seconds = [(start / 1000, end / 1000) for start, end in silence_intervals]
30
+
31
+ if not silence_intervals_seconds:
+ # No silence detected: return the audio unchanged
+ return audio_array
+ first_silence_end = silence_intervals_seconds[0][1]
32
+ # Create audio from first silence end to end of audio
33
+ trimmed_audio = audio_array[:,:,int(first_silence_end * frame_rate):] # Slice audio from first silence end to the end
34
+
35
+ # trimmed_audio.export(output_path, format="wav")
36
+ return trimmed_audio
37
+
38
+ def process_all_audio_files(root_dir, output_dir):
39
+ # Use glob to find all .wav files in the directory and its subdirectories
40
+ wav_files = glob.glob(os.path.join(root_dir, '**', '*.wav'), recursive=True)
41
+ for input_path in tqdm(wav_files, desc="Processing audio files"):
42
+ relative_path = os.path.relpath(input_path, root_dir)
43
+ output_subdir = os.path.join(output_dir, os.path.dirname(relative_path))
44
+ os.makedirs(output_subdir, exist_ok=True)
45
+ output_path = os.path.join(output_subdir, os.path.basename(input_path))
46
+ # Load the wav, trim the leading silence, and write the trimmed audio out
+ audio_seg = AudioSegment.from_wav(input_path)
+ samples = np.array(audio_seg.get_array_of_samples(), dtype=np.float32) / 32768.0
+ wav = torch.from_numpy(samples).reshape(1, 1, -1)
+ trimmed = detect_and_trim_silence(wav, audio_seg.frame_rate)
+ trimmed_int16 = (np.clip(trimmed.squeeze(), -1.0, 1.0) * 32767).astype(np.int16)
+ AudioSegment(trimmed_int16.tobytes(), sample_width=2, frame_rate=audio_seg.frame_rate, channels=1).export(output_path, format="wav")
47
+
48
+ # Use the function to process all audio files
49
+ if __name__ == "__main__":
50
+ root_directory = "/workspace/tezuesh/omega-v2v/.predictions_warmup/moshi/audio/"
51
+ output_directory = "/workspace/tezuesh/omega-v2v/.predictions_warmup/moshi/trimmed/"
52
+ process_all_audio_files(root_directory, output_directory)
53
+
config.yaml ADDED
@@ -0,0 +1,6 @@
1
+ model:
2
+ name: moshi
3
+ version: 1.0
4
+ description: "Moshi Pretrained Model"
5
+ author: "Tezuesh"
6
+ license: "MIT"
hotkey.txt ADDED
@@ -0,0 +1 @@
1
+ 5FeqmebkCWfepQPgSkrEHRwtpUmHGASF4BNERZDs9pvKFtcD
inference.py ADDED
@@ -0,0 +1,244 @@
1
+ import torch
2
+ import numpy as np
3
+ import torchaudio
4
+ import sentencepiece
5
+ import logging
6
+ from pathlib import Path
7
+ from moshi.models import loaders, LMGen
8
+
9
+ logging.basicConfig(level=logging.INFO)
10
+ logger = logging.getLogger(__name__)
11
+
12
+ class InferenceRecipe:
13
+ """Handles model inference for the Any-to-Any model."""
14
+
15
+ def __init__(self, model_path: str, device: str='cuda'):
16
+ """Initialize the model.
17
+
18
+ Args:
19
+ model_path (str): Path to model directory with pre-downloaded files
20
+ device (str): Device to run on ('cuda' or 'cpu')
21
+ """
22
+ self.device = torch.device(device)
23
+ self.model_path = Path(model_path)
24
+
25
+ # Set sample rate and frame rate
26
+ self.sample_rate = 24000 # Based on model config in loaders.py
27
+ self.frame_rate = 12.5 # Based on model config in loaders.py
28
+
29
+ # Initialize all model components
30
+ logger.info(f"Initializing models from {model_path}")
31
+ self.mimi, self.text_tokenizer, self.lm_gen = self._initialize_models()
32
+ logger.info("Model initialization complete")
33
+
34
+ def _initialize_models(self):
35
+ """Initialize all required model components."""
36
+ print("Initializing models...")
37
+
38
+ try:
39
+ # Load MIMI model for encoding/decoding
40
+ mimi_path = self.model_path / loaders.MIMI_NAME
41
+ if not mimi_path.exists():
42
+ raise RuntimeError(f"MIMI model not found at {mimi_path}")
43
+ logger.info(f"Loading MIMI model from {mimi_path}")
44
+ mimi = loaders.get_mimi(str(mimi_path), device=self.device)
45
+ mimi.set_num_codebooks(8)
46
+
47
+ # Load text tokenizer
48
+ tokenizer_path = self.model_path / loaders.TEXT_TOKENIZER_NAME
49
+ if not tokenizer_path.exists():
50
+ raise RuntimeError(f"Text tokenizer not found at {tokenizer_path}")
51
+ logger.info(f"Loading text tokenizer from {tokenizer_path}")
52
+ text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer_path))
53
+
54
+ # Load language model
55
+ moshi_path = self.model_path / loaders.MOSHI_NAME
56
+ if not moshi_path.exists():
57
+ raise RuntimeError(f"Language model not found at {moshi_path}")
58
+ logger.info(f"Loading language model from {moshi_path}")
59
+ moshi = loaders.get_moshi_lm(str(moshi_path), device=self.device)
60
+ lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)
61
+
62
+ return mimi, text_tokenizer, lm_gen
63
+
64
+ except Exception as e:
65
+ logger.error(f"Model initialization failed: {str(e)}")
66
+ raise
67
+
68
+ def _load_audio(self, audio_array: np.ndarray, sample_rate: int):
69
+ """Load and preprocess audio."""
70
+ try:
71
+ # Convert to tensor
72
+ wav = torch.from_numpy(audio_array).float().unsqueeze(0).unsqueeze(0)
73
+
74
+ # Resample if needed
75
+ if sample_rate != self.sample_rate:
76
+ logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
77
+ wav = torchaudio.transforms.Resample(
78
+ orig_freq=sample_rate,
79
+ new_freq=self.sample_rate
80
+ )(wav)
81
+
82
+ # Ensure frame alignment
83
+ frame_size = int(self.sample_rate / self.frame_rate)
84
+ orig_length = wav.shape[-1]
85
+ wav = wav[:, :, :(wav.shape[-1] // frame_size) * frame_size]
86
+ if wav.shape[-1] != orig_length:
87
+ logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")
88
+
89
+ return wav
90
+
91
+ except Exception as e:
92
+ logger.error(f"Audio loading failed: {str(e)}")
93
+ raise
94
+
95
+ def _pad_codes(self, all_codes, time_seconds=30):
96
+ """Pad codes to minimum length if needed."""
97
+ try:
98
+ min_frames = int(time_seconds * self.frame_rate)
99
+ frame_size = int(self.sample_rate / self.frame_rate)
100
+
101
+ if len(all_codes) < min_frames:
102
+ frames_to_add = min_frames - len(all_codes)
103
+ logger.info(f"Padding {frames_to_add} frames to reach minimum length")
104
+ with torch.no_grad(), self.mimi.streaming(batch_size=1):
105
+ chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
106
+ for _ in range(frames_to_add):
107
+ additional_code = self.mimi.encode(chunk)
108
+ all_codes.append(additional_code)
109
+
110
+ return all_codes
111
+
112
+ except Exception as e:
113
+ logger.error(f"Code padding failed: {str(e)}")
114
+ raise
115
+
116
+ def _encode_audio(self, wav: torch.Tensor):
117
+ """Convert audio to codes."""
118
+ try:
119
+ frame_size = int(self.sample_rate / self.frame_rate)
120
+ all_codes = []
121
+
122
+ with torch.no_grad(), self.mimi.streaming(batch_size=1):
123
+ for offset in range(0, wav.shape[-1], frame_size):
124
+ frame = wav[:, :, offset: offset + frame_size]
125
+ codes = self.mimi.encode(frame.to(self.device))
126
+ assert codes.shape[-1] == 1, f"Expected code shape (*, *, 1), got {codes.shape}"
127
+ all_codes.append(codes)
128
+
129
+ logger.info(f"Encoded {len(all_codes)} frames")
130
+ return all_codes
131
+
132
+ except Exception as e:
133
+ logger.error(f"Audio encoding failed: {str(e)}")
134
+ raise
135
+
136
+ def _warmup(self):
137
+ """Run a warmup pass."""
138
+ try:
139
+ frame_size = int(self.sample_rate / self.frame_rate)
140
+ chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
141
+ codes = self.mimi.encode(chunk)
142
+
143
+ with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
144
+ tokens = self.lm_gen.step(codes[:, :, 0:1])
145
+ if tokens is not None:
146
+ _ = self.mimi.decode(tokens[:, 1:])
147
+
148
+ if self.device.type == "cuda":
+ torch.cuda.synchronize()
149
+ logger.info("Warmup pass completed")
150
+
151
+ except Exception as e:
152
+ logger.error(f"Warmup failed: {str(e)}")
153
+ raise
154
+
155
+ def _generate(self, all_codes):
156
+ """Generate audio and text from codes."""
157
+ try:
158
+ out_wav_chunks = []
159
+ text_output = []
160
+
161
+ with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
162
+ for i, code in enumerate(all_codes):
163
+ assert code.shape == (1, 8, 1), f"Expected code shape (1, 8, 1), got {code.shape}"
164
+ tokens_out = self.lm_gen.step(code.to(self.device))
165
+
166
+ if tokens_out is not None:
167
+ # Generate audio
168
+ wav_chunk = self.mimi.decode(tokens_out[:, 1:])
169
+ out_wav_chunks.append(wav_chunk)
170
+
171
+ # Generate text if available
172
+ text_token = tokens_out[0, 0, 0].item()
173
+ if text_token not in (0, 3):
174
+ _text = self.text_tokenizer.id_to_piece(text_token)
175
+ _text = _text.replace("▁", " ")
176
+ text_output.append(_text)
177
+
178
+ if (i + 1) % 100 == 0:
179
+ logger.info(f"Processed {i + 1}/{len(all_codes)} frames")
180
+
181
+ wav = torch.cat(out_wav_chunks, dim=-1)
182
+ text = ''.join(text_output)
183
+
184
+ logger.info(f"Generated {wav.shape[-1]} samples of audio and {len(text)} characters of text")
185
+ return wav, text
186
+
187
+ except Exception as e:
188
+ logger.error(f"Generation failed: {str(e)}")
189
+ raise
190
+
191
+ def inference(self, audio_array: np.ndarray, sample_rate: int) -> dict:
192
+ """Run inference on input audio.
193
+
194
+ Args:
195
+ audio_array (np.ndarray): Input audio as numpy array
196
+ sample_rate (int): Sample rate of input audio
197
+
198
+ Returns:
199
+ dict: Contains generated audio array and optional transcribed text
200
+ """
201
+ try:
202
+ logger.info(f"Starting inference on {len(audio_array)} samples at {sample_rate}Hz")
203
+
204
+ # Load and preprocess audio
205
+ wav = self._load_audio(audio_array, sample_rate)
206
+ wav = wav.to(self.device)
207
+
208
+ # Convert to codes
209
+ all_codes = self._encode_audio(wav)
210
+ all_codes = self._pad_codes(all_codes)
211
+
212
+ # Warmup pass
213
+ self._warmup()
214
+
215
+ # Generate output
216
+ out_wav, text = self._generate(all_codes)
217
+
218
+ # Convert output to numpy
219
+ output = out_wav.cpu().numpy().squeeze()
220
+
221
+ logger.info("Inference completed successfully")
222
+ return {
223
+ "audio": output,
224
+ "text": text
225
+ }
226
+
227
+ except Exception as e:
228
+ logger.error(f"Inference failed: {str(e)}")
229
+ raise
230
+
231
+ if __name__ == "__main__":
232
+ # Example usage
233
+ import librosa
234
+
235
+ # Initialize model
236
+ model = InferenceRecipe("/path/to/models", device="cuda")
237
+
238
+ # Load test audio
239
+ audio, sr = librosa.load("test.wav", sr=None)
240
+
241
+ # Run inference
242
+ result = model.inference(audio, sr)
243
+ print(f"Generated {len(result['audio'])} samples of audio")
244
+ print(f"Generated text: {result['text']}")
mimi_tokenizer.py ADDED
@@ -0,0 +1,48 @@
1
+ from moshi import models
2
+ loaders = models.loaders
3
+ from huggingface_hub import hf_hub_download
4
+ import torch
5
+ from pydub import AudioSegment
6
+ import numpy as np
7
+
8
+ MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
9
+ DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'
10
+
11
+
12
+ device = "cuda" if torch.cuda.is_available() else "cpu"
13
+ mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
14
+ mimi = loaders.get_mimi(mimi_weight, device=device)
15
+
16
+ def encode_audio(mimi, wav, device):
17
+ frame_size = int(mimi.sample_rate / mimi.frame_rate)
18
+ all_codes = []
19
+ with torch.no_grad(), mimi.streaming(batch_size=1):
20
+ for offset in range(0, wav.shape[-1], frame_size):
21
+ frame = wav[:, :, offset: offset + frame_size]
22
+ codes = mimi.encode(frame.to(device))
23
+ assert codes.shape[-1] == 1, codes.shape
24
+ all_codes.append(codes)
25
+
26
+ return all_codes
27
+
28
+
29
+ def load_audio(wav_path, mimi):
30
+ audio = AudioSegment.from_wav(wav_path)
31
+ samples = np.array(audio.get_array_of_samples())
32
+ samples = samples.astype(np.float32) / (2**15 if audio.sample_width == 2 else 2**31)
33
+ wav = torch.from_numpy(samples).float().unsqueeze(0).unsqueeze(0)
34
+
35
+ if audio.frame_rate != mimi.sample_rate:
36
+ wav = torch.nn.functional.interpolate(wav, scale_factor=mimi.sample_rate/audio.frame_rate, mode='linear', align_corners=False)
37
+
38
+ frame_size = int(mimi.sample_rate / mimi.frame_rate)
39
+ wav = wav[:, :, :(wav.shape[-1] // frame_size) * frame_size]
40
+
41
+ return wav
42
+
43
+
44
+
45
+
46
+
47
+
48
+
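49
+ # Minimal usage sketch (the wav path below is a placeholder, not part of the original file):
50
+ # load a wav with the helpers above and encode it into Mimi codes.
51
+ if __name__ == "__main__":
52
+ wav = load_audio("example.wav", mimi)
53
+ codes = encode_audio(mimi, wav, device)
54
+ print(f"Encoded {len(codes)} frames with {codes[0].shape[1]} codebooks each")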
models/model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5b835c664f3830bf453808cbca9bfbcc9de332c328cc01cbffdfbaba2a8838a7
3
+ size 15375500136
models/tokenizer-e351c8d8-checkpoint125.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:09b782f0629851a271227fb9d36db65c041790365f11bbe5d3d59369cf863f50
3
+ size 384644900
models/tokenizer_spm_32k_3.model ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:78d4336533ddc26f9acf7250d7fb83492152196c6ea4212c841df76933f18d2d
3
+ size 552778
moshi/chunk_silence.py ADDED
@@ -0,0 +1,53 @@
1
+ import os
2
+ from pydub import AudioSegment
3
+ from pydub.silence import detect_silence
4
+ import glob
5
+ from tqdm import tqdm
6
+ import numpy as np
+ import torch
7
+
8
+ def detect_and_trim_silence(audio_array, frame_rate, min_silence_duration=1000, silence_threshold=-40):
9
+ # Convert the input tensor to 16-bit PCM bytes
10
+ audio_array = audio_array.detach().cpu().numpy()
11
+ audio_bytes = (audio_array * 32767).astype(np.int16).tobytes()
12
+
13
+ # Create AudioSegment from raw audio bytes
14
+ audio = AudioSegment(
15
+ data=audio_bytes,
16
+ sample_width=2, # 16-bit audio = 2 bytes
17
+ frame_rate=frame_rate,
18
+ channels=1 # Mono audio
19
+ )
20
+
21
+ # Detect silence
22
+ silence_intervals = detect_silence(
23
+ audio,
24
+ min_silence_len=min_silence_duration,
25
+ silence_thresh=silence_threshold
26
+ )
27
+
28
+ # Convert milliseconds to seconds
29
+ silence_intervals_seconds = [(start / 1000, end / 1000) for start, end in silence_intervals]
30
+
31
+ if not silence_intervals_seconds:
+ # No silence detected: return the audio unchanged
+ return audio_array
+ first_silence_end = silence_intervals_seconds[0][1]
32
+ # Create audio from first silence end to end of audio
33
+ trimmed_audio = audio_array[:,:,int(first_silence_end * frame_rate):] # Slice audio from first silence end to the end
34
+
35
+ # trimmed_audio.export(output_path, format="wav")
36
+ return trimmed_audio
37
+
38
+ def process_all_audio_files(root_dir, output_dir):
39
+ # Use glob to find all .wav files in the directory and its subdirectories
40
+ wav_files = glob.glob(os.path.join(root_dir, '**', '*.wav'), recursive=True)
41
+ for input_path in tqdm(wav_files, desc="Processing audio files"):
42
+ relative_path = os.path.relpath(input_path, root_dir)
43
+ output_subdir = os.path.join(output_dir, os.path.dirname(relative_path))
44
+ os.makedirs(output_subdir, exist_ok=True)
45
+ output_path = os.path.join(output_subdir, os.path.basename(input_path))
46
+ # Load the wav, trim the leading silence, and write the trimmed audio out
+ audio_seg = AudioSegment.from_wav(input_path)
+ samples = np.array(audio_seg.get_array_of_samples(), dtype=np.float32) / 32768.0
+ wav = torch.from_numpy(samples).reshape(1, 1, -1)
+ trimmed = detect_and_trim_silence(wav, audio_seg.frame_rate)
+ trimmed_int16 = (np.clip(trimmed.squeeze(), -1.0, 1.0) * 32767).astype(np.int16)
+ AudioSegment(trimmed_int16.tobytes(), sample_width=2, frame_rate=audio_seg.frame_rate, channels=1).export(output_path, format="wav")
47
+
48
+ # Use the function to process all audio files
49
+ if __name__ == "__main__":
50
+ root_directory = "/workspace/tezuesh/omega-v2v/.predictions_warmup/moshi/audio/"
51
+ output_directory = "/workspace/tezuesh/omega-v2v/.predictions_warmup/moshi/trimmed/"
52
+ process_all_audio_files(root_directory, output_directory)
53
+
moshi/models/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+ """
5
+ Models for the compression model Mimi and the Moshi language model.
6
+ """
7
+
8
+ # flake8: noqa
9
+ from moshi.models.compression import (
10
+ CompressionModel,
11
+ MimiModel,
12
+ )
13
+ from moshi.models.lm import LMModel, LMGen
14
+ from moshi.models.loaders import get_mimi, get_moshi_lm
moshi/models/compression.py ADDED
@@ -0,0 +1,474 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Part of this file is adapted from encodec.py in https://github.com/facebookresearch/audiocraft
6
+ # released under the following license.
7
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
8
+ # All rights reserved.
9
+ #
10
+ # This source code is licensed under the license found in the
11
+ # LICENSE file in the root directory of this source tree.
12
+ """Compression models or wrapper around existing models. In particular, provides the implementation
13
+ for Mimi. Also defines the main interface that a model must follow to be usable as an audio tokenizer.
14
+ """
15
+
16
+ from abc import abstractmethod
17
+ from contextlib import nullcontext
18
+ from dataclasses import dataclass
19
+ import logging
20
+ import typing as tp
21
+
22
+ import torch
23
+ from torch import nn
24
+
25
+
26
+ from moshi.quantization import (
27
+ QuantizedResult,
28
+ BaseQuantizer,
29
+ SplitResidualVectorQuantizer,
30
+ ResidualVectorQuantizer,
31
+ )
32
+ from moshi.modules.resample import ConvDownsample1d, ConvTrUpsample1d
33
+ from moshi.modules.streaming import StreamingModule, State
34
+ from moshi.utils.compile import no_compile, CUDAGraphed
35
+
36
+
37
+ logger = logging.getLogger()
38
+
39
+
40
+ class CompressionModel(StreamingModule[State]):
41
+ """Base API for all compression model that aim at being used as audio tokenizers
42
+ with a language model.
43
+ """
44
+
45
+ @abstractmethod
46
+ def forward(self, x: torch.Tensor) -> QuantizedResult: ...
47
+
48
+ @abstractmethod
49
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
50
+ """See `MimiModel.encode`."""
51
+ ...
52
+
53
+ @abstractmethod
54
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
55
+ """See `MimiModel.decode`."""
56
+ ...
57
+
58
+ @abstractmethod
59
+ def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
60
+ """Decode from the discrete codes to continuous latent space."""
61
+ ...
62
+
63
+ @property
64
+ @abstractmethod
65
+ def channels(self) -> int: ...
66
+
67
+ @property
68
+ @abstractmethod
69
+ def frame_rate(self) -> float: ...
70
+
71
+ @property
72
+ @abstractmethod
73
+ def sample_rate(self) -> int: ...
74
+
75
+ @property
76
+ @abstractmethod
77
+ def cardinality(self) -> int: ...
78
+
79
+ @property
80
+ @abstractmethod
81
+ def num_codebooks(self) -> int: ...
82
+
83
+ @property
84
+ @abstractmethod
85
+ def total_codebooks(self) -> int: ...
86
+
87
+ @abstractmethod
88
+ def set_num_codebooks(self, n: int):
89
+ """Set the active number of codebooks used by the quantizer."""
90
+ ...
91
+
92
+
93
+ @dataclass
94
+ class _MimiState:
95
+ graphed_tr_enc: CUDAGraphed | None
96
+ graphed_tr_dec: CUDAGraphed | None
97
+
98
+ def reset(self):
99
+ pass
100
+
101
+
102
+ class MimiModel(CompressionModel[_MimiState]):
103
+ """Mimi model operating on the raw waveform.
104
+
105
+ Args:
106
+ encoder (nn.Module): Encoder network.
107
+ decoder (nn.Module): Decoder network.
108
+ quantizer (qt.BaseQuantizer): Quantizer network.
109
+ frame_rate (float): Final frame rate of the quantized representation.
110
+ encoder_frame_rate (float): frame rate of the encoder model. Note that if `frame_rate != encoder_frame_rate`,
111
+ the latent will be resampled linearly to match the desired `frame_rate` before and after quantization.
112
+ sample_rate (int): Audio sample rate.
113
+ channels (int): Number of audio channels.
114
+ causal (bool): Whether to use a causal version of the model.
115
+ encoder_transformer (nn.Module or None): optional transformer for the encoder.
116
+ decoder_transformer (nn.Module or None): optional transformer for the decoder.
117
+ resample_method (str): method to use for resampling the latent space before the quantizer.
118
+ upsample_channel_wise_bug (bool): controls whether the upsampling is channel wise.
119
+ Defaults to true to reproduce bug in original implementation.
120
+ freeze_encoder: whether to freeze the encoder weights.
121
+ freeze_quantizer: whether to freeze the quantizer weights.
122
+ freeze_quantizer_level: If positive, freeze the quantizer up to this level.
123
+ torch_compile_encoder_decoder (bool): if True, uses torch.compile on the encoder / decoder.
124
+ Deactivated by default for training as this is incompatible at the moment with weight norm.
125
+ See https://github.com/pytorch/pytorch/issues/121902
126
+ Also this seems to work well with 2.2.0, but completely fail with 2.4.0.
127
+ """
128
+
129
+ def __init__(
130
+ self,
131
+ encoder: nn.Module,
132
+ decoder: nn.Module,
133
+ quantizer: BaseQuantizer,
134
+ frame_rate: float,
135
+ encoder_frame_rate: float,
136
+ sample_rate: int,
137
+ channels: int,
138
+ causal: bool = False,
139
+ encoder_transformer: tp.Optional[nn.Module] = None,
140
+ decoder_transformer: tp.Optional[nn.Module] = None,
141
+ resample_method: str = "interpolate",
142
+ upsample_channel_wise_bug: bool = True,
143
+ freeze_encoder: bool = False,
144
+ freeze_quantizer: bool = False,
145
+ freeze_quantizer_level: int = -1,
146
+ torch_compile_encoder_decoder: bool = False,
147
+ ):
148
+ super().__init__()
149
+ self.encoder = encoder
150
+ self.decoder = decoder
151
+ self.encoder_transformer = encoder_transformer
152
+ self.decoder_transformer = decoder_transformer
153
+ self.quantizer = quantizer
154
+ self._frame_rate = frame_rate
155
+ self._sample_rate = sample_rate
156
+ self._channels = channels
157
+ self.encoder_frame_rate = encoder_frame_rate
158
+ self.torch_compile_encoder_decoder = torch_compile_encoder_decoder
159
+
160
+ if freeze_encoder:
161
+ for p in self.encoder.parameters():
162
+ p.requires_grad = False
163
+ if self.encoder_transformer is not None:
164
+ for p in self.encoder_transformer.parameters():
165
+ p.requires_grad = False
166
+ for name, p in self.quantizer.named_parameters():
167
+ if name.endswith("input_proj.weight"):
168
+ p.requires_grad = False
169
+ if freeze_quantizer:
170
+ self.quantizer.ema_frozen_(True)
171
+ self.freeze_quantizer = freeze_quantizer
172
+ self.freeze_quantizer_level = (
173
+ freeze_quantizer_level
174
+ if freeze_quantizer_level > 0
175
+ else self.quantizer.num_codebooks
176
+ )
177
+
178
+ # We will need the dimension for the resampling. In general the encoder will be a SeanetEncoder
179
+ # which exposes a `dimension` attribute.
180
+ dimension = encoder.dimension
181
+ assert isinstance(
182
+ dimension, int
183
+ ), f"Dimension should be int, got {dimension} of type {type(dimension)}."
184
+ self.dimension = dimension
185
+
186
+ assert resample_method in [
187
+ "interpolate",
188
+ "conv",
189
+ "avg_pool",
190
+ ], f"Invalid resample_method {resample_method}"
191
+ self.resample_method = resample_method
192
+ if encoder_frame_rate != frame_rate:
193
+ assert not (
194
+ causal and resample_method == "interpolate"
195
+ ), "Cannot interpolate with causal model."
196
+ if resample_method in ["conv", "avg_pool"]:
197
+ assert (
198
+ self.encoder_frame_rate > self.frame_rate
199
+ ), "Cannot upsample with conv."
200
+ downsample_stride = self.encoder_frame_rate / self.frame_rate
201
+ assert downsample_stride == int(
202
+ downsample_stride
203
+ ), f"Only integer strides are supported, got {downsample_stride}"
204
+ learnt = resample_method == "conv"
205
+ self.downsample = ConvDownsample1d(
206
+ int(downsample_stride),
207
+ dimension=dimension,
208
+ learnt=learnt,
209
+ causal=causal,
210
+ )
211
+ if freeze_encoder:
212
+ for p in self.downsample.parameters():
213
+ p.requires_grad = False
214
+ self.upsample = ConvTrUpsample1d(
215
+ int(downsample_stride),
216
+ dimension=dimension,
217
+ learnt=learnt,
218
+ causal=causal,
219
+ channel_wise=upsample_channel_wise_bug,
220
+ )
221
+
222
+ def _init_streaming_state(self, batch_size: int) -> _MimiState:
223
+ device = next(self.parameters()).device
224
+ disable = device.type != 'cuda'
225
+ graphed_tr_dec = None
226
+ graphed_tr_enc = None
227
+ if self.encoder_transformer is not None:
228
+ graphed_tr_enc = CUDAGraphed(self.encoder_transformer, disable=disable)
229
+ if self.decoder_transformer is not None:
230
+ graphed_tr_dec = CUDAGraphed(self.decoder_transformer, disable=disable)
231
+ return _MimiState(graphed_tr_enc, graphed_tr_dec)
232
+
233
+ @property
234
+ def channels(self) -> int:
235
+ return self._channels
236
+
237
+ @property
238
+ def frame_rate(self) -> float:
239
+ return self._frame_rate
240
+
241
+ @property
242
+ def sample_rate(self) -> int:
243
+ return self._sample_rate
244
+
245
+ @property
246
+ def total_codebooks(self):
247
+ """Total number of quantizer codebooks available."""
248
+ return self.quantizer.total_codebooks
249
+
250
+ @property
251
+ def num_codebooks(self):
252
+ """Active number of codebooks used by the quantizer."""
253
+ return self.quantizer.num_codebooks
254
+
255
+ def set_num_codebooks(self, n: int):
256
+ """Set the active number of codebooks used by the quantizer."""
257
+ self.quantizer.set_num_codebooks(n)
258
+
259
+ @property
260
+ def cardinality(self):
261
+ """Cardinality of each codebook."""
262
+ return self.quantizer.cardinality
263
+
264
+ def _to_framerate(self, x: torch.Tensor):
265
+ # Convert from the encoder frame rate to the overall framerate.
266
+ _, _, length = x.shape
267
+ frame_rate = self.encoder_frame_rate
268
+ new_frame_rate = self.frame_rate
269
+ if frame_rate == new_frame_rate:
270
+ return x
271
+ if self.resample_method == "interpolate":
272
+ target_length = int(length * new_frame_rate / frame_rate)
273
+ return nn.functional.interpolate(x, size=target_length, mode="linear")
274
+ else:
275
+ return self.downsample(x)
276
+
277
+ def _to_encoder_framerate(self, x: torch.Tensor):
278
+ # Convert from overall framerate to the encoder frame rate.
279
+ _, _, length = x.shape
280
+ frame_rate = self.encoder_frame_rate
281
+ new_frame_rate = self.frame_rate
282
+ if frame_rate == new_frame_rate:
283
+ return x
284
+ if self.resample_method == "interpolate":
285
+ target_length = int(length * new_frame_rate / frame_rate)
286
+ return nn.functional.interpolate(x, size=target_length, mode="linear")
287
+ else:
288
+ return self.upsample(x)
289
+
290
+ @property
291
+ def _context_for_encoder_decoder(self):
292
+ if self.torch_compile_encoder_decoder:
293
+ return nullcontext()
294
+ else:
295
+ return no_compile()
296
+
297
+ def forward(self, x: torch.Tensor) -> QuantizedResult:
298
+ assert x.dim() == 3
299
+ length = x.shape[-1]
300
+ extra_metrics: tp.Dict[str, torch.Tensor] = {}
301
+
302
+ if self.freeze_quantizer:
303
+ if isinstance(self.quantizer, SplitResidualVectorQuantizer):
304
+ self.quantizer.rvq_first.eval()
305
+ for i in range(
306
+ self.freeze_quantizer_level - self.quantizer.n_q_semantic
307
+ ):
308
+ self.quantizer.rvq_rest.vq.layers[i].eval()
309
+ elif isinstance(self.quantizer, ResidualVectorQuantizer):
310
+ for i in range(self.freeze_quantizer_level):
311
+ self.quantizer.vq.layers[i].eval()
312
+ else:
313
+ raise ValueError(f"Unsupported quantizer type {type(self.quantizer)}")
314
+
315
+ with self._context_for_encoder_decoder:
316
+ emb = self.encoder(x)
317
+ if self.encoder_transformer is not None:
318
+ (emb,) = self.encoder_transformer(emb)
319
+ emb = self._to_framerate(emb)
320
+ expected_length = self.frame_rate * length / self.sample_rate
321
+ # Checking that we have the proper length given the advertised frame rate.
322
+ assert abs(emb.shape[-1] - expected_length) < 1, (
323
+ emb.shape[-1],
324
+ expected_length,
325
+ )
326
+
327
+ q_res = self.quantizer(emb, self.frame_rate)
328
+ emb = q_res.x
329
+ emb = self._to_encoder_framerate(emb)
330
+ if self.decoder_transformer is not None:
331
+ (emb,) = self.decoder_transformer(emb)
332
+
333
+ with self._context_for_encoder_decoder:
334
+ out = self.decoder(emb)
335
+
336
+ # remove extra padding added by the encoder and decoder
337
+ assert out.shape[-1] >= length, (out.shape[-1], length)
338
+ out = out[..., :length]
339
+
340
+ q_res.x = out
341
+ q_res.metrics.update(extra_metrics)
342
+ return q_res
343
+
344
+ def _encode_to_unquantized_latent(self, x: torch.Tensor) -> torch.Tensor:
345
+ """Projects a batch of waveforms to unquantized latent space.
346
+
347
+ Args:
348
+ x (torch.Tensor): Float tensor of shape [B, C, T].
349
+
350
+ Returns:
351
+ Unquantized embeddings.
352
+ """
353
+ assert (
354
+ x.dim() == 3
355
+ ), f"CompressionModel._encode_to_unquantized_latent expects audio of shape [B, C, T] but got {x.shape}"
356
+ state = self._streaming_state
357
+ with self._context_for_encoder_decoder:
358
+ emb = self.encoder(x)
359
+ if self.encoder_transformer is not None:
360
+ if state is None:
361
+ (emb,) = self.encoder_transformer(emb)
362
+ else:
363
+ assert state.graphed_tr_enc is not None
364
+ (emb,) = state.graphed_tr_enc(emb)
365
+ emb = self._to_framerate(emb)
366
+ return emb
367
+
368
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
369
+ """Encode the given input tensor to quantized representation.
370
+
371
+ Args:
372
+ x (torch.Tensor): Float tensor of shape [B, C, T]
373
+
374
+ Returns:
375
+ codes (torch.Tensor): an int tensor of shape [B, K, T]
376
+ with K the number of codebooks used and T the timestep.
377
+ """
378
+ emb = self._encode_to_unquantized_latent(x)
379
+ codes = self.quantizer.encode(emb)
380
+ return codes
381
+
382
+ def encode_to_latent(self, x: torch.Tensor, quantize: bool = True) -> torch.Tensor:
383
+ """Projects a batch of waveforms to latent space.
384
+
385
+ Args:
386
+ x (torch.Tensor): Float tensor of shape [B, C, T].
387
+
388
+ Returns:
389
+ Embeddings, either quantized or not.
390
+ """
391
+ emb = self._encode_to_unquantized_latent(x)
392
+ if not quantize:
393
+ return emb
394
+ else:
395
+ codes = self.quantizer.encode(emb)
396
+ return self.decode_latent(codes)
397
+
398
+ def decode(self, codes: torch.Tensor):
399
+ """Decode the given codes to a reconstructed representation.
400
+
401
+ Args:
402
+ codes (torch.Tensor): Int tensor of shape [B, K, T]
403
+
404
+ Returns:
405
+ out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
406
+ """
407
+ state = self._streaming_state
408
+ emb = self.decode_latent(codes)
409
+ emb = self._to_encoder_framerate(emb)
410
+ if self.decoder_transformer is not None:
411
+ if state is None:
412
+ (emb,) = self.decoder_transformer(emb)
413
+ else:
414
+ assert state.graphed_tr_dec is not None
415
+ (emb,) = state.graphed_tr_dec(emb)
416
+ with self._context_for_encoder_decoder:
417
+ out = self.decoder(emb)
418
+ # out contains extra padding added by the encoder and decoder
419
+ return out
420
+
421
+ def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
422
+ """Decode from the discrete codes to continuous latent space."""
423
+ return self.quantizer.decode(codes)
424
+
425
+
426
+ class WrapperCompressionModel(CompressionModel[State]):
427
+ """Base API for CompressionModel wrappers that do not depend on external frameworks."""
428
+
429
+ def __init__(self, model: CompressionModel):
430
+ super().__init__()
431
+ self.model = model
432
+
433
+ def forward(self, x: torch.Tensor) -> QuantizedResult:
434
+ return self.model.forward(x)
435
+
436
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
437
+ return self.model.encode(x)
438
+
439
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
440
+ return self.model.decode(codes)
441
+
442
+ def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
443
+ return self.model.decode_latent(codes)
444
+
445
+ def set_num_codebooks(self, n: int):
446
+ self.model.set_num_codebooks(n)
447
+
448
+ @property
449
+ def quantizer(self):
450
+ return self.model.quantizer
451
+
452
+ @property
453
+ def channels(self) -> int:
454
+ return self.model.channels
455
+
456
+ @property
457
+ def frame_rate(self) -> float:
458
+ return self.model.frame_rate
459
+
460
+ @property
461
+ def sample_rate(self) -> int:
462
+ return self.model.sample_rate
463
+
464
+ @property
465
+ def cardinality(self) -> int:
466
+ return self.model.cardinality
467
+
468
+ @property
469
+ def num_codebooks(self) -> int:
470
+ return self.model.num_codebooks
471
+
472
+ @property
473
+ def total_codebooks(self) -> int:
474
+ return self.model.total_codebooks
moshi/models/lm.py ADDED
@@ -0,0 +1,487 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+ from dataclasses import dataclass
12
+ from functools import partial
13
+ import logging
14
+ import typing as tp
15
+
16
+ import torch
17
+ from torch import nn
18
+
19
+ from moshi.utils.sampling import sample_token
20
+ from moshi.utils.compile import CUDAGraphed
21
+ from moshi.modules.streaming import StreamingContainer, StreamingModule
22
+ from moshi.modules.transformer import (
23
+ StreamingTransformer,
24
+ create_norm_fn,
25
+ )
26
+
27
+
28
+ logger = logging.getLogger(__name__)
29
+
30
+
31
+ class ScaledEmbedding(nn.Embedding):
32
+ """Boost learning rate for embeddings (with `scale`).
33
+
34
+ Args:
35
+ norm (bool): if True, uses a layer norm after the embedding.
36
+ zero_idx (int): special value indicating that the output should be exactly 0.
37
+ """
38
+
39
+ def __init__(self, *args, norm: bool = False, zero_idx: int = -1, **kwargs):
40
+ super().__init__(*args, **kwargs)
41
+ self.norm = None
42
+ if norm:
43
+ self.norm = create_norm_fn("layer_norm", self.embedding_dim)
44
+ assert zero_idx < 0, "Please use negative values for the zero_idx."
45
+ self.zero_idx = zero_idx
46
+
47
+ def forward(self, input, *args, **kwargs):
48
+ is_zero = input == self.zero_idx
49
+ zero = torch.zeros(1, dtype=input.dtype, device=input.device)
50
+ input = input.clamp(min=0)
51
+ y = super().forward(input, *args, **kwargs)
52
+ if self.norm is not None:
53
+ y = self.norm(y)
54
+ y = torch.where(is_zero[..., None], zero, y)
55
+ return y
56
+
57
+
58
+ class LMModel(StreamingContainer):
59
+ """Transformer-based language model on multiple streams of codes.
60
+
61
+ Args:
62
+ n_q (int): Number of parallel streams to model as input.
63
+ dep_q (int): Number of parallel streams to model in the depformer.
64
+ card (int): Cardinality, vocabulary size.
65
+ text_card (int): Cardinality of the text vocabulary.
66
+ dim (int): Dimension of the transformer encoder.
67
+ num_heads (int): Number of heads for the transformer encoder.
68
+ hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
69
+ norm (str): Normalization method.
70
+ norm_emb (bool): Whether to normalize embeddings.
71
+ bias_proj (bool): Use bias for output projections.
72
+ depformer_*: params used for the Depformer Transformer, all the other will be shared.
73
+ depformer_multi_linear (bool): if True, uses one linear layer per codebook to project the
74
+ output of the main transformer to the Depformer latent space.
75
+ depformer_dim_feedforward (int| list[int]| None): If None, defaults to hidden_scale * depformer_dim.
76
+ existing_text_padding_id (int, optional): if set, token id to use for text padding; otherwise a
77
+ dedicated extra padding token is appended to the text vocabulary.
78
+ same_initial (bool): if True, uses the same initial tokens for both text and audio mode.
79
+ **kwargs: Additional parameters for the transformer encoder.
80
+ """
81
+
82
+ def __init__(
83
+ self,
84
+ delays: tp.List[int] = [0],
85
+ n_q: int = 8,
86
+ dep_q: int = 8,
87
+ card: int = 1024,
88
+ text_card: int = 32000,
89
+ dim: int = 128,
90
+ num_heads: int = 8,
91
+ hidden_scale: int = 4,
92
+ norm: str = "layer_norm",
93
+ norm_emb: bool = False,
94
+ bias_proj: bool = False,
95
+ depformer_dim: int = 256,
96
+ depformer_dim_feedforward: int | list[int] | None = None,
97
+ depformer_multi_linear: bool = False,
98
+ depformer_weights_per_step: bool = False,
99
+ depformer_pos_emb: str = "sin",
100
+ existing_text_padding_id: tp.Optional[int] = None,
101
+ context: tp.Optional[int] = None,
102
+ device=None,
103
+ dtype=None,
104
+ **kwargs,
105
+ ):
106
+ super().__init__()
107
+ self.n_q = n_q
108
+ self.dep_q = dep_q
109
+ self.card = card
110
+ self.text_card = text_card
111
+ assert len(delays) == self.num_codebooks, "unexpected number of delays"
112
+ self.delays = delays
113
+ self.dim = dim
114
+ self.existing_text_padding_id = existing_text_padding_id
115
+ self.context = context
116
+ kwargs["context"] = context
117
+ EmbeddingFactory = partial(
118
+ ScaledEmbedding,
119
+ norm=norm_emb,
120
+ device=device,
121
+ dtype=dtype,
122
+ zero_idx=self.zero_token_id,
123
+ )
124
+ self.emb = nn.ModuleList(
125
+ [EmbeddingFactory(self.card + 1, dim) for _ in range(n_q)]
126
+ )
127
+ # Text card + padding token (if not in the original tokenizer)
128
+ extra_text = self.existing_text_padding_id is None
129
+ # Unlike for audio, here we authorize the model to output the special token.
130
+ self.text_emb = EmbeddingFactory(text_card + 1, dim)
131
+ self.text_linear = nn.Linear(dim, text_card + extra_text, bias=bias_proj)
132
+ depformer_prefix = "depformer_"
133
+ main_kwargs = {
134
+ k: v for k, v in kwargs.items() if not k.startswith(depformer_prefix)
135
+ }
136
+ self.transformer = StreamingTransformer(
137
+ d_model=dim,
138
+ num_heads=num_heads,
139
+ dim_feedforward=int(hidden_scale * dim),
140
+ norm=norm,
141
+ device=device,
142
+ dtype=dtype,
143
+ **main_kwargs,
144
+ )
145
+ self.out_norm = create_norm_fn(norm, dim)
146
+ self.depformer_multi_linear = depformer_multi_linear
147
+ kwargs_dep = main_kwargs.copy()
148
+ kwargs_dep.update(
149
+ {
150
+ k.removeprefix(depformer_prefix): v
151
+ for k, v in kwargs.items()
152
+ if k.startswith(depformer_prefix)
153
+ }
154
+ )
155
+ kwargs_dep["positional_embedding"] = depformer_pos_emb
156
+ kwargs_dep["context"] = None
157
+ if depformer_weights_per_step:
158
+ kwargs_dep["weights_per_step"] = dep_q
159
+ if depformer_multi_linear:
160
+ # One linear layer per codebook to project different information from the main model.
161
+ self.depformer_in = nn.ModuleList(
162
+ [nn.Linear(dim, depformer_dim, bias=False) for _ in range(dep_q)]
163
+ )
164
+ else:
165
+ self.depformer_in = nn.ModuleList(
166
+ [nn.Linear(dim, depformer_dim, bias=False)]
167
+ )
168
+ # Only using up to dep_q - 1 because the last codebook is never an input to Depformer.
169
+ self.depformer_emb = nn.ModuleList(
170
+ [EmbeddingFactory(self.card + 1, depformer_dim) for _ in range(dep_q - 1)]
171
+ )
172
+ self.depformer_text_emb = EmbeddingFactory(text_card + 1, depformer_dim)
173
+ if depformer_dim_feedforward is None:
174
+ depformer_dim_feedforward = int(hidden_scale * depformer_dim)
175
+ self.depformer = StreamingTransformer(
176
+ d_model=depformer_dim,
177
+ dim_feedforward=depformer_dim_feedforward,
178
+ norm=norm,
179
+ device=device,
180
+ dtype=dtype,
181
+ **kwargs_dep,
182
+ )
183
+ self.depformer.set_streaming_propagate(False)
184
+ dim = depformer_dim # we will directly apply the next linears to the output of the Depformer.
185
+
186
+ self.linears = nn.ModuleList(
187
+ [nn.Linear(dim, self.card, bias=bias_proj) for _ in range(dep_q)]
188
+ )
189
+
190
+ @property
191
+ def initial_token_id(self) -> int:
192
+ """Token id for the start of sequence (audio)."""
193
+ return self.card
194
+
195
+ @property
196
+ def text_initial_token_id(self) -> int:
197
+ """Token id for the start of sequence (text)."""
198
+ return self.text_card
199
+
200
+ @property
201
+ def text_padding_token_id(self) -> int:
202
+ """Token id for text padding."""
203
+ if self.existing_text_padding_id is None:
204
+ return self.text_card
205
+ else:
206
+ return self.existing_text_padding_id
207
+
208
+ @property
209
+ def end_of_text_padding_id(self) -> int:
210
+ """Token id for optionally marking the last padding step for a word."""
211
+ return 0
212
+
213
+ @property
214
+ def zero_token_id(self) -> int:
215
+ """Special value in the input tokens, indicating that no sampling should
216
+ happen for that value, and no input should be given to the model."""
217
+ return -1
218
+
219
+ @property
220
+ def ungenerated_token_id(self) -> int:
221
+ """Special value that can be provided in the prompt to indicate that this specific
222
+ value should be predicted and sampled. This allows for partial teacher forcing, by generating
223
+ one modality, with the other one fixed.
224
+ """
225
+ return -2
226
+
227
+ @property
228
+ def device(self):
229
+ first_param = next(iter(self.parameters()))
230
+ return first_param.device
231
+
232
+ @property
233
+ def num_codebooks(self) -> int:
234
+ return self.n_q + 1
235
+
236
+ @property
237
+ def num_audio_codebooks(self) -> int:
238
+ return self.n_q
239
+
240
+ @property
241
+ def audio_offset(self) -> int:
242
+ return 1
243
+
244
+ def _get_initial_token(self) -> torch.Tensor:
245
+ # Returns the initial token that will be fed to the model to predict the very first timestep.
246
+ # The output shape will be [B, K, 1].
247
+ device = next(iter(self.parameters())).device
248
+ zero = torch.full(
249
+ [1, 1, 1], self.zero_token_id, device=device, dtype=torch.long
250
+ )
251
+ special = torch.full_like(zero, self.initial_token_id)
252
+
253
+ text_special = torch.full_like(zero, self.text_initial_token_id)
254
+ audio_token = special
255
+ text_token = text_special
256
+ audio_token = audio_token.expand(-1, self.num_audio_codebooks, -1)
257
+ token = torch.cat([text_token, audio_token], dim=1)
258
+ return token
259
+
260
+ def forward_text(
261
+ self,
262
+ sequence: torch.Tensor,
263
+ ) -> tuple[torch.Tensor, torch.Tensor]:
264
+ B, K, S = sequence.shape
265
+ assert (
266
+ K == self.num_codebooks
267
+ ), f"Sequence shape {sequence.shape} must match the number of codebooks."
268
+ input_sequence = sequence
269
+ input_ = None
270
+ for cb_index in range(self.num_audio_codebooks):
271
+ audio_emb = self.emb[cb_index](
272
+ input_sequence[:, cb_index + self.audio_offset]
273
+ )
274
+ input_ = audio_emb if input_ is None else input_ + audio_emb
275
+ text_emb = self.text_emb(input_sequence[:, 0])
276
+ input_ = text_emb if input_ is None else input_ + text_emb
277
+ transformer_out = self.transformer(input_)
278
+
279
+ if self.out_norm:
280
+ transformer_out = self.out_norm(transformer_out)
281
+ assert isinstance(transformer_out, torch.Tensor)
282
+ text_logits = self.text_linear(transformer_out)
283
+ text_logits = text_logits[:, None]
284
+ return transformer_out, text_logits
285
+
286
+ def forward_depformer(
287
+ self,
288
+ depformer_cb_index: int,
289
+ sequence: torch.Tensor,
290
+ transformer_out: torch.Tensor,
291
+ ) -> torch.Tensor:
292
+ B, K, S = sequence.shape
293
+ assert (
294
+ K == 1
295
+ ), f"Codebooks for Depformer streaming should be passed 1 by 1, got {K}."
296
+ assert (
297
+ S == 1
298
+ ), f"Steps for Depformer streaming should be passed 1 by 1, got {S}."
299
+ assert (
300
+ transformer_out.shape[1] == 1
301
+ ), "Transformer out should be a for a single step."
302
+ last_token_input: tp.Optional[torch.Tensor] = None
303
+ depformer_input = transformer_out
304
+ if self.depformer_multi_linear:
305
+ depformer_input = self.depformer_in[depformer_cb_index](depformer_input)
306
+ else:
307
+ depformer_input = self.depformer_in[0](depformer_input)
308
+ if depformer_cb_index == 0:
309
+ last_token_input = self.depformer_text_emb(sequence[:, 0])
310
+ else:
311
+ last_token_input = self.depformer_emb[depformer_cb_index - 1](
312
+ sequence[:, 0]
313
+ )
314
+ depformer_input = depformer_input + last_token_input
315
+ assert depformer_input.shape[1] == 1
316
+ # depformer_input is [B, 1, depformer_dim].
317
+ # The streaming state of the depformer ensures that the proper layer is run.
318
+ dep_output = self.depformer(depformer_input)
319
+ logits = self.linears[depformer_cb_index](dep_output)
320
+ logits = logits[:, None]
321
+ assert logits.dim() == 4, logits.shape # [B, Ka, S, card]
322
+ return logits
323
+
324
+
325
+ @dataclass
326
+ class _LMGenState:
327
+ cache: torch.Tensor
328
+ initial: torch.Tensor
329
+ graphed_main: CUDAGraphed
330
+ graphed_depth: CUDAGraphed
331
+ offset: int = 0
332
+
333
+ def reset(self):
334
+ self.offset = 0
335
+
336
+
337
+ class LMGen(StreamingModule[_LMGenState]):
338
+ def __init__(
339
+ self,
340
+ lm_model: LMModel,
341
+ use_sampling: bool = True,
342
+ temp: float = 0.8,
343
+ temp_text: float = 0.7,
344
+ top_k: int = 250,
345
+ top_k_text: int = 25,
346
+ check: bool = False,
347
+ ):
348
+ assert not lm_model.training, "generation shouldn't be used in training mode."
349
+ super().__init__()
350
+
351
+ self.lm_model = lm_model
352
+ self.use_sampling = use_sampling
353
+ self.temp = temp
354
+ self.temp_text = temp_text
355
+ self.top_k = top_k
356
+ self.top_k_text = top_k_text
357
+ self.check = check
358
+ self.max_delay = max(
359
+ lm_model.delays
360
+ ) # with delays, we need to generate a few more time steps.
361
+ self.delays_cuda = torch.tensor(
362
+ lm_model.delays, device=lm_model.device, dtype=torch.long
363
+ )
364
+
365
+ def _init_streaming_state(self, batch_size: int) -> _LMGenState:
366
+ lm_model = self.lm_model
367
+ initial = lm_model._get_initial_token()
368
+ cache = torch.full(
369
+ (batch_size, self.lm_model.num_codebooks, self.max_delay + 2),
370
+ lm_model.ungenerated_token_id,
371
+ device=lm_model.device,
372
+ dtype=torch.long,
373
+ )
374
+
375
+ disable = lm_model.device.type != 'cuda'
376
+ graphed_main = CUDAGraphed(lm_model.forward_text, disable=disable)
377
+ graphed_depth = CUDAGraphed(self.depformer_step, disable=disable)
378
+
379
+ return _LMGenState(cache, initial, graphed_main, graphed_depth)
380
+
381
+ @torch.no_grad()
382
+ def step(self, input_tokens: torch.Tensor) -> torch.Tensor | None:
383
+ state = self._streaming_state
384
+ if state is None:
385
+ raise RuntimeError(
386
+ "You should wrap those calls with a `with lm_gen.streaming(): ...`."
387
+ )
388
+ lm_model = self.lm_model
389
+
390
+ assert input_tokens.dim() == 3, "Shape should be [B, K, T]."
391
+ B, Ki, S = input_tokens.shape
392
+ assert S == 1, "Only support being given steps one by one."
393
+ needed_tokens = lm_model.num_codebooks - lm_model.dep_q - 1
394
+ assert (
395
+ Ki == needed_tokens
396
+ ), f"We expect {needed_tokens} tokens from the user stream, got {Ki}."
397
+
398
+ CT = state.cache.shape[2]
399
+ for q_other in range(input_tokens.shape[1]):
400
+ k = lm_model.dep_q + 1 + q_other
401
+ delay = lm_model.delays[k]
402
+ write_position = (state.offset + delay) % CT
403
+ state.cache[:, k, write_position : write_position + 1] = input_tokens[
404
+ :, q_other
405
+ ]
406
+
407
+ position = state.offset % CT
408
+ for k, delay in enumerate(lm_model.delays):
409
+ # Only for the very beginning, we extend the initial token for the acoustic
410
+ # tokens that are delayed, and thus have no good value to take.
411
+ if state.offset <= delay:
412
+ state.cache[:, k, position] = state.initial[:, k, 0]
413
+ input_ = state.cache[:, :, position : position + 1]
414
+
415
+ if self.check:
416
+ # Check that we are not feeding in any value that is not generated yet.
417
+ assert not (input_ == lm_model.ungenerated_token_id).any(), (
418
+ state.offset,
419
+ input_,
420
+ )
421
+ assert (input_[:, lm_model.audio_offset :] <= lm_model.card).all(), input_
422
+ assert (input_[:, :1] <= lm_model.text_card).all()
423
+
424
+ transformer_out, text_logits = state.graphed_main(input_)
425
+ # Shape of text_logits should be [B, K_text=1, T=1, Card_text]
426
+ text_token = sample_token(
427
+ text_logits.float(),
428
+ self.use_sampling,
429
+ self.temp_text,
430
+ self.top_k_text,
431
+ )
432
+ assert text_token.dim() == 3, text_token.shape
433
+ assert text_token.shape[2] == 1
434
+ assert text_token.shape[1] == 1, "Only one text stream supported."
435
+ text_token = text_token[:, 0, 0] # shape is [B]
436
+ audio_tokens = state.graphed_depth(text_token, transformer_out)
437
+
438
+ # ensure we don't overwrite prompt tokens, we only write over ungenerated tokens
439
+ state.offset += 1
440
+ position = state.offset % CT
441
+ state.cache[:, 0, position] = text_token
442
+ state.cache[:, 1 : lm_model.dep_q + 1, position] = audio_tokens
443
+
444
+ if state.offset <= self.max_delay:
445
+ return None
446
+ B = state.cache.shape[0]
447
+ gen_delays_cuda = self.delays_cuda[: lm_model.dep_q + 1]
448
+ index = (
449
+ ((state.offset - self.max_delay + gen_delays_cuda) % CT)
450
+ .view(1, -1, 1)
451
+ .expand(B, -1, 1)
452
+ )
453
+ out = state.cache.gather(dim=2, index=index)
454
+ return out
455
+
456
+ def depformer_step(
457
+ self,
458
+ text_token: torch.Tensor,
459
+ transformer_out: torch.Tensor,
460
+ ) -> torch.Tensor:
461
+ (B,) = text_token.shape
462
+ prev_token = text_token
463
+ lm_model = self.lm_model
464
+ depformer_tokens: list[torch.Tensor] = []
465
+ assert not lm_model.depformer.is_streaming
466
+ with lm_model.depformer.streaming(B):
467
+ for cb_index in range(lm_model.dep_q):
468
+ input_ = prev_token[:, None, None]
469
+ logits = lm_model.forward_depformer(cb_index, input_, transformer_out)
470
+ next_token = sample_token(
471
+ logits.float(),
472
+ self.use_sampling,
473
+ self.temp,
474
+ self.top_k,
475
+ )
476
+ assert next_token.shape == (B, 1, 1)
477
+ next_token = next_token[:, 0, 0] # shape is B
478
+ depformer_tokens.append(next_token)
479
+ prev_token = next_token
480
+
481
+ assert len(depformer_tokens) == lm_model.dep_q, (
482
+ len(depformer_tokens),
483
+ lm_model.dep_q,
484
+ )
485
+ out = torch.stack(depformer_tokens, dim=1)
486
+ assert out.shape == (B, lm_model.dep_q), out.shape
487
+ return out
moshi/models/loaders.py ADDED
@@ -0,0 +1,159 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+ """Retrieves the pretrained models for Moshi and Mimi."""
5
+ from pathlib import Path
6
+
7
+ from safetensors.torch import load_model
8
+ import torch
9
+
10
+ from moshi.models.compression import MimiModel
11
+ from moshi.models.lm import LMModel
12
+ from moshi.modules import SEANetEncoder, SEANetDecoder, transformer
13
+ from moshi.quantization import SplitResidualVectorQuantizer
14
+
15
+ SAMPLE_RATE = 24000
16
+ FRAME_RATE = 12.5
17
+
18
+ TEXT_TOKENIZER_NAME = 'tokenizer_spm_32k_3.model'
19
+ MOSHI_NAME = 'model.safetensors'
20
+ MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
21
+ DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'
22
+
23
+
24
+ _seanet_kwargs = {
25
+ "channels": 1,
26
+ "dimension": 512,
27
+ "causal": True,
28
+ "n_filters": 64,
29
+ "n_residual_layers": 1,
30
+ "activation": "ELU",
31
+ "compress": 2,
32
+ "dilation_base": 2,
33
+ "disable_norm_outer_blocks": 0,
34
+ "kernel_size": 7,
35
+ "residual_kernel_size": 3,
36
+ "last_kernel_size": 3,
37
+ # We train using weight_norm but then the weights are pre-processed for inference so
38
+ # that we can use a normal convolution.
39
+ "norm": "none",
40
+ "pad_mode": "constant",
41
+ "ratios": [8, 6, 5, 4],
42
+ "true_skip": True,
43
+ }
44
+ _quantizer_kwargs = {
45
+ "dimension": 256,
46
+ "n_q": 32,
47
+ "bins": 2048,
48
+ "input_dimension": _seanet_kwargs["dimension"],
49
+ "output_dimension": _seanet_kwargs["dimension"],
50
+ }
51
+ _transformer_kwargs = {
52
+ "d_model": _seanet_kwargs["dimension"],
53
+ "num_heads": 8,
54
+ "num_layers": 8,
55
+ "causal": True,
56
+ "layer_scale": 0.01,
57
+ "context": 250,
58
+ "conv_layout": True,
59
+ "max_period": 10000,
60
+ "gating": "none",
61
+ "norm": "layer_norm",
62
+ "positional_embedding": "rope",
63
+ "dim_feedforward": 2048,
64
+ "input_dimension": _seanet_kwargs["dimension"],
65
+ "output_dimensions": [_seanet_kwargs["dimension"]],
66
+ }
67
+
68
+ _lm_kwargs = {
69
+ "dim": 4096,
70
+ "text_card": 32000,
71
+ "existing_text_padding_id": 3,
72
+ "n_q": 16,
73
+ "dep_q": 8,
74
+ "card": _quantizer_kwargs["bins"],
75
+ "num_heads": 32,
76
+ "num_layers": 32,
77
+ "hidden_scale": 4.125,
78
+ "causal": True,
79
+ "layer_scale": None,
80
+ "context": 3000,
81
+ "max_period": 10000,
82
+ "gating": "silu",
83
+ "norm": "rms_norm_f32",
84
+ "positional_embedding": "rope",
85
+ "depformer_dim": 1024,
86
+ "depformer_dim_feedforward": int(4.125 * 1024),
87
+ "depformer_num_heads": 16,
88
+ "depformer_num_layers": 6,
89
+ "depformer_causal": True,
90
+ "depformer_layer_scale": None,
91
+ "depformer_multi_linear": True,
92
+ "depformer_context": 8,
93
+ "depformer_max_period": 10000,
94
+ "depformer_gating": "silu",
95
+ "depformer_pos_emb": "none",
96
+ "depformer_weights_per_step": True,
97
+ "delays": [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1],
98
+ }
99
+
100
+
101
+ def _is_safetensors(path: Path | str) -> bool:
102
+ return Path(path).suffix in (".safetensors", ".sft", ".sfts")
103
+
104
+
105
+ def get_mimi(filename: str | Path,
106
+ device: torch.device | str = 'cpu') -> MimiModel:
107
+ """Return a pretrained Mimi model."""
108
+ encoder = SEANetEncoder(**_seanet_kwargs)
109
+ decoder = SEANetDecoder(**_seanet_kwargs)
110
+ encoder_transformer = transformer.ProjectedTransformer(
111
+ device=device, **_transformer_kwargs
112
+ )
113
+ decoder_transformer = transformer.ProjectedTransformer(
114
+ device=device, **_transformer_kwargs
115
+ )
116
+ quantizer = SplitResidualVectorQuantizer(
117
+ **_quantizer_kwargs,
118
+ )
119
+ model = MimiModel(
120
+ encoder,
121
+ decoder,
122
+ quantizer,
123
+ channels=1,
124
+ sample_rate=SAMPLE_RATE,
125
+ frame_rate=FRAME_RATE,
126
+ encoder_frame_rate=SAMPLE_RATE / encoder.hop_length,
127
+ causal=True,
128
+ resample_method="conv",
129
+ encoder_transformer=encoder_transformer,
130
+ decoder_transformer=decoder_transformer,
131
+ ).to(device=device)
132
+ model.eval()
133
+ if _is_safetensors(filename):
134
+ load_model(model, filename)
135
+ else:
136
+ pkg = torch.load(filename, "cpu")
137
+ model.load_state_dict(pkg["model"])
138
+ model.set_num_codebooks(8)
139
+ return model
140
+
141
+
142
+ def get_moshi_lm(filename: str | Path,
143
+ device: torch.device | str = 'cpu') -> LMModel:
144
+ dtype = torch.bfloat16
145
+ model = LMModel(
146
+ device=device,
147
+ dtype=dtype,
148
+ **_lm_kwargs,
149
+ ).to(device=device, dtype=dtype)
150
+ model.eval()
151
+ if _is_safetensors(filename):
152
+ load_model(model, filename)
153
+ else:
154
+ pkg = torch.load(
155
+ filename,
156
+ "cpu",
157
+ )
158
+ model.load_state_dict(pkg["fsdp_best_state"]["model"])
159
+ return model
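A possible way to exercise these loaders (a sketch only: it assumes the `moshi` package from this commit is installed, e.g. via the Dockerfile's `pip install -e .`, and that `huggingface_hub` can fetch the referenced files from DEFAULT_REPO):

import torch
from huggingface_hub import hf_hub_download
from moshi.models import loaders

device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Fetch the checkpoints named above and build the two models.
mimi_path = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_path, device=device)

moshi_path = hf_hub_download(loaders.DEFAULT_REPO, loaders.MOSHI_NAME)
moshi_lm = loaders.get_moshi_lm(moshi_path, device=device)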
moshi/modules/__init__.py ADDED
@@ -0,0 +1,23 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+ """Modules used for building the models."""
11
+
12
+ # flake8: noqa
13
+ from .conv import (
14
+ NormConv1d,
15
+ NormConvTranspose1d,
16
+ StreamingConv1d,
17
+ StreamingConvTranspose1d,
18
+ pad_for_conv1d,
19
+ pad1d,
20
+ unpad1d,
21
+ )
22
+ from .seanet import SEANetEncoder, SEANetDecoder
23
+ from .transformer import StreamingTransformer
moshi/modules/conv.py ADDED
@@ -0,0 +1,329 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+ from dataclasses import dataclass
12
+ import math
13
+ import typing as tp
14
+ import warnings
15
+
16
+ import torch
17
+ from torch import nn
18
+ from torch.nn import functional as F
19
+ from torch.nn.utils import weight_norm
20
+
21
+ from .streaming import RawStreamingConv1d, RawStreamingConvTranspose1d, StreamingModule
22
+
23
+
24
+ CONV_NORMALIZATIONS = frozenset(["none", "weight_norm"])
25
+
26
+
27
+ class TransposedLayerNorm(nn.Module):
28
+ """LayerNorm for [B, C, T] inputs."""
29
+
30
+ def __init__(self, **kwargs):
31
+ super().__init__()
32
+ self.layer_norm = nn.LayerNorm(**kwargs)
33
+
34
+ def forward(self, x):
35
+ x = x.transpose(1, 2)
36
+ x = self.layer_norm(x)
37
+ return x.transpose(1, 2)
38
+
39
+
40
+ def apply_parametrization_norm(module: nn.Module, norm: str = "none"):
41
+ assert norm in CONV_NORMALIZATIONS
42
+ if norm == "weight_norm":
43
+ return weight_norm(module)
44
+ else:
45
+ # We already check was in CONV_NORMALIZATION, so any other choice
46
+ # doesn't need reparametrization.
47
+ return module
48
+
49
+
50
+ def get_extra_padding_for_conv1d(
51
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
52
+ ) -> int:
53
+ """See `pad_for_conv1d`."""
54
+ length = x.shape[-1]
55
+ n_frames = (length - kernel_size + padding_total) / stride + 1
56
+ ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
57
+ return ideal_length - length
58
+
59
+
60
+ def pad_for_conv1d(
61
+ x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
62
+ ):
63
+ """Pad for a convolution to make sure that the last window is full.
64
+ Extra padding is added at the end. This is required to ensure that we can rebuild
65
+ an output of the same length, as otherwise, even with padding, some time steps
66
+ might get removed.
67
+ For instance, with total padding = 4, kernel size = 4, stride = 2:
68
+ 0 0 1 2 3 4 5 0 0 # (0s are padding)
69
+ 1 2 3 # (output frames of a convolution, last 0 is never used)
70
+ 0 0 1 2 3 4 5 0 # (output of tr. conv., but pos. 5 is going to get removed as padding)
71
+ 1 2 3 4 # once the padding is removed, we are missing one time step!
72
+ """
73
+ extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
74
+ return F.pad(x, (0, extra_padding))
75
+
76
+
77
+ def pad1d(
78
+ x: torch.Tensor,
79
+ paddings: tp.Tuple[int, int],
80
+ mode: str = "constant",
81
+ value: float = 0.0,
82
+ ):
83
+ """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
84
+ If this is the case, we insert extra 0 padding to the right before the reflection happens.
85
+ """
86
+ length = x.shape[-1]
87
+ padding_left, padding_right = paddings
88
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
89
+ if mode == "reflect":
90
+ max_pad = max(padding_left, padding_right)
91
+ extra_pad = 0
92
+ if length <= max_pad:
93
+ extra_pad = max_pad - length + 1
94
+ x = F.pad(x, (0, extra_pad))
95
+ padded = F.pad(x, paddings, mode, value)
96
+ end = padded.shape[-1] - extra_pad
97
+ return padded[..., :end]
98
+ else:
99
+ return F.pad(x, paddings, mode, value)
100
+
101
+
102
+ def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
103
+ """Remove padding from x, handling properly zero padding. Only for 1d!"""
104
+ padding_left, padding_right = paddings
105
+ assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
106
+ assert (padding_left + padding_right) <= x.shape[-1]
107
+ end = x.shape[-1] - padding_right
108
+ return x[..., padding_left:end]
109
+
110
+
111
+ class NormConv1d(nn.Module):
112
+ """Wrapper around Conv1d and normalization applied to this conv
113
+ to provide a uniform interface across normalization approaches.
114
+ """
115
+
116
+ def __init__(
117
+ self,
118
+ *args,
119
+ causal: bool = False,
120
+ norm: str = "none",
121
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
122
+ **kwargs,
123
+ ):
124
+ super().__init__()
125
+ self.conv = apply_parametrization_norm(
126
+ RawStreamingConv1d(*args, **kwargs), norm
127
+ )
128
+ self.norm_type = norm
129
+
130
+ def forward(self, x):
131
+ x = self.conv(x)
132
+ return x
133
+
134
+
135
+ class NormConvTranspose1d(nn.Module):
136
+ """Wrapper around ConvTranspose1d and normalization applied to this conv
137
+ to provide a uniform interface across normalization approaches.
138
+ """
139
+
140
+ def __init__(
141
+ self,
142
+ *args,
143
+ causal: bool = False,
144
+ norm: str = "none",
145
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
146
+ **kwargs,
147
+ ):
148
+ super().__init__()
149
+ self.convtr = apply_parametrization_norm(
150
+ RawStreamingConvTranspose1d(*args, **kwargs), norm
151
+ )
152
+ self.norm_type = norm
153
+
154
+ def forward(self, x):
155
+ x = self.convtr(x)
156
+ return x
157
+
158
+
159
+ @dataclass
160
+ class _StreamingConv1dState:
161
+ padding_to_add: int
162
+ original_padding_to_add: int
163
+
164
+ def reset(self):
165
+ self.padding_to_add = self.original_padding_to_add
166
+
167
+
168
+ class StreamingConv1d(StreamingModule[_StreamingConv1dState]):
169
+ """Conv1d with some builtin handling of asymmetric or causal padding
170
+ and normalization.
171
+ """
172
+
173
+ def __init__(
174
+ self,
175
+ in_channels: int,
176
+ out_channels: int,
177
+ kernel_size: int,
178
+ stride: int = 1,
179
+ dilation: int = 1,
180
+ groups: int = 1,
181
+ bias: bool = True,
182
+ causal: bool = False,
183
+ norm: str = "none",
184
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
185
+ pad_mode: str = "reflect",
186
+ ):
187
+ super().__init__()
188
+ # warn user on unusual setup between dilation and stride
189
+ if stride > 1 and dilation > 1:
190
+ warnings.warn(
191
+ "StreamingConv1d has been initialized with stride > 1 and dilation > 1"
192
+ f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
193
+ )
194
+ self.conv = NormConv1d(
195
+ in_channels,
196
+ out_channels,
197
+ kernel_size,
198
+ stride,
199
+ dilation=dilation,
200
+ groups=groups,
201
+ bias=bias,
202
+ causal=causal,
203
+ norm=norm,
204
+ norm_kwargs=norm_kwargs,
205
+ )
206
+ self.causal = causal
207
+ self.pad_mode = pad_mode
208
+
209
+ @property
210
+ def _stride(self) -> int:
211
+ return self.conv.conv.stride[0]
212
+
213
+ @property
214
+ def _kernel_size(self) -> int:
215
+ return self.conv.conv.kernel_size[0]
216
+
217
+ @property
218
+ def _effective_kernel_size(self) -> int:
219
+ dilation = self.conv.conv.dilation[0]
220
+ return (
221
+ self._kernel_size - 1
222
+ ) * dilation + 1 # effective kernel size with dilations
223
+
224
+ @property
225
+ def _padding_total(self) -> int:
226
+ return self._effective_kernel_size - self._stride
227
+
228
+ def _init_streaming_state(self, batch_size: int) -> _StreamingConv1dState:
229
+ assert self.causal, "streaming is only supported for causal convs"
230
+ return _StreamingConv1dState(self._padding_total, self._padding_total)
231
+
232
+ def forward(self, x):
233
+ B, C, T = x.shape
234
+ padding_total = self._padding_total
235
+ extra_padding = get_extra_padding_for_conv1d(
236
+ x, self._effective_kernel_size, self._stride, padding_total
237
+ )
238
+ state = self._streaming_state
239
+ if state is None:
240
+ if self.causal:
241
+ # Left padding for causal
242
+ x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
243
+ else:
244
+ # Asymmetric padding required for odd strides
245
+ padding_right = padding_total // 2
246
+ padding_left = padding_total - padding_right
247
+ x = pad1d(
248
+ x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
249
+ )
250
+ else:
251
+ if state.padding_to_add > 0 and x.shape[-1] > 0:
252
+ x = pad1d(x, (state.padding_to_add, 0), mode=self.pad_mode)
253
+ state.padding_to_add = 0
254
+ return self.conv(x)
255
+
256
+
257
+ @dataclass
258
+ class _StreamingConvTr1dState:
259
+ pass
260
+
261
+ def reset(self):
262
+ pass
263
+
264
+
265
+ class StreamingConvTranspose1d(StreamingModule[_StreamingConvTr1dState]):
266
+ """ConvTranspose1d with some builtin handling of asymmetric or causal padding
267
+ and normalization.
268
+ """
269
+
270
+ def __init__(
271
+ self,
272
+ in_channels: int,
273
+ out_channels: int,
274
+ kernel_size: int,
275
+ stride: int = 1,
276
+ groups: int = 1,
277
+ bias: bool = True,
278
+ causal: bool = False,
279
+ norm: str = "none",
280
+ trim_right_ratio: float = 1.0,
281
+ norm_kwargs: tp.Dict[str, tp.Any] = {},
282
+ ):
283
+ super().__init__()
284
+ self.convtr = NormConvTranspose1d(
285
+ in_channels,
286
+ out_channels,
287
+ kernel_size,
288
+ stride,
289
+ groups=groups,
290
+ bias=bias,
291
+ causal=causal,
292
+ norm=norm,
293
+ norm_kwargs=norm_kwargs,
294
+ )
295
+ self.causal = causal
296
+ self.trim_right_ratio = trim_right_ratio
297
+ assert (
298
+ self.causal or self.trim_right_ratio == 1.0
299
+ ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
300
+ assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0
301
+
302
+ def _init_streaming_state(self, batch_size: int) -> _StreamingConvTr1dState:
303
+ assert self.causal, "streaming is only supported for causal convtrs"
304
+ return _StreamingConvTr1dState()
305
+
306
+ def forward(self, x):
307
+ kernel_size = self.convtr.convtr.kernel_size[0]
308
+ stride = self.convtr.convtr.stride[0]
309
+ padding_total = kernel_size - stride
310
+
311
+ y = self.convtr(x)
312
+
313
+ if not self.is_streaming:
314
+ # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
315
+ # removed at the very end, when keeping only the right length for the output,
316
+ # as removing it here would require also passing the length at the matching layer
317
+ # in the encoder.
318
+ if self.causal:
319
+ # Trim the padding on the right according to the specified ratio
320
+ # if trim_right_ratio = 1.0, trim everything from right
321
+ padding_right = math.ceil(padding_total * self.trim_right_ratio)
322
+ padding_left = padding_total - padding_right
323
+ y = unpad1d(y, (padding_left, padding_right))
324
+ else:
325
+ # Asymmetric padding required for odd strides
326
+ padding_right = padding_total // 2
327
+ padding_left = padding_total - padding_right
328
+ y = unpad1d(y, (padding_left, padding_right))
329
+ return y
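A small sketch of the padding arithmetic above (assumes the `moshi` package from this commit is importable):

import torch
from moshi.modules.conv import get_extra_padding_for_conv1d, pad_for_conv1d

kernel_size, stride, padding_total = 4, 2, 2

x = torch.randn(1, 1, 10)   # n_frames = (10 - 4 + 2) / 2 + 1 = 5, already an integer
print(get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total))  # 0

x = torch.randn(1, 1, 11)   # n_frames = 5.5 -> rounded up to 6, ideal_length = 12
print(pad_for_conv1d(x, kernel_size, stride, padding_total).shape)  # torch.Size([1, 1, 12])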
moshi/modules/gating.py ADDED
@@ -0,0 +1,82 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ import torch
6
+ from torch import nn
7
+ from torch.nn import functional as F
8
+
9
+ from ..utils.compile import torch_compile_lazy
10
+
11
+
12
+ @torch_compile_lazy
13
+ def gating_forward_kernel(
14
+ weight_in: torch.Tensor, weight_out: torch.Tensor, activation, x: torch.Tensor
15
+ ):
16
+ x = F.linear(x, weight_in)
17
+ B, T, _ = x.shape
18
+ x = x.view(B, T, 2, -1)
19
+ x = activation(x[..., 0, :]) * x[..., 1, :]
20
+ x = F.linear(x, weight_out)
21
+ return x
22
+
23
+
24
+ class ActivationGating(nn.Module):
25
+ """
26
+ Gating FFN layer, using the given activation.
27
+ Args:
28
+ dim (int): dimension of the input and output of the transformer.
29
+ activation (any callable Tensor to Tensor): activation function to use.
30
+ **factory_kwargs: other kwargs passed to the linear layer, in particular device and dtype.
31
+ """
32
+
33
+ _fsdp_final = True
34
+
35
+ def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
36
+ super().__init__()
37
+ # A plain FFN with dim_feedforward = 4 * d would use 8 d^2 params; here we have
38
+ # 2 * h * d + h * d = 3 h * d, so matching 8 d^2 gives
39
+ # h = 8 d / 3. Following Hervé's advice we use h = (21 / 8) d as an approximation.
40
+ if dim_feedforward == 4 * dim:
41
+ hidden = (21 * dim) // 8
42
+ else:
43
+ hidden = (2 * dim_feedforward) // 3
44
+ self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
45
+ self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
46
+ self.activation = activation
47
+
48
+ def forward(self, x: torch.Tensor):
49
+ return gating_forward_kernel(
50
+ self.linear_in.weight, self.linear_out.weight, self.activation, x
51
+ )
52
+
53
+
54
+ def _get_activation(name: str):
55
+ if name in ["sigmoid", "tanh", "relu"]:
56
+ return getattr(torch, name)
57
+ elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]:
58
+ return getattr(torch.nn.functional, name)
59
+ elif name == "identity":
60
+ return torch.nn.Identity()
61
+ else:
62
+ raise ValueError(f"Unknown activation {name}")
63
+
64
+
65
+ def _make_gating(
66
+ name: str, dim: int, dim_feedforward: int, **factory_kwargs
67
+ ) -> nn.Module:
68
+ return ActivationGating(
69
+ dim, dim_feedforward, _get_activation(name), **factory_kwargs
70
+ )
71
+
72
+
73
+ def make_gating(
74
+ name: str, dim: int, dim_feedforward: int, **factory_kwargs
75
+ ) -> nn.Module:
76
+ gating = _make_gating(name, dim, dim_feedforward, **factory_kwargs)
77
+ max_params = 2 * dim * dim_feedforward
78
+ params = sum(p.numel() for p in gating.parameters())
79
+ assert (
80
+ params <= max_params
81
+ ), f"{name} gating has {params} params, max is {max_params}"
82
+ return gating
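A quick numeric check of the sizing rule in ActivationGating for an illustrative dim = 4096 with dim_feedforward = 4 * dim (plain Python, a sketch only):

dim, dim_feedforward = 4096, 4 * 4096
hidden = (21 * dim) // 8                    # 10752, close to 8 / 3 * dim
params = 2 * dim * hidden + hidden * dim    # linear_in + linear_out, no biases
max_params = 2 * dim * dim_feedforward      # budget of a plain two-layer FFN
print(hidden, params, max_params, params <= max_params)  # 10752 132120576 134217728 True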
moshi/modules/resample.py ADDED
@@ -0,0 +1,119 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ import typing as tp
6
+
7
+ from einops import rearrange
8
+ import torch
9
+ from torch import nn
10
+
11
+ from .conv import StreamingConv1d, StreamingConvTranspose1d
12
+
13
+
14
+ class ConvDownsample1d(nn.Module):
15
+ """
16
+ Downsampling by some integer amount `stride` using convolutions
17
+ with a kernel size of twice the stride.
18
+ If `causal` is True, the output uses a causal convolution.
19
+ """
20
+
21
+ def __init__(
22
+ self,
23
+ stride: int,
24
+ dimension: tp.Optional[int] = None,
25
+ causal: bool = False,
26
+ learnt: bool = False,
27
+ channel_wise: bool = False,
28
+ ):
29
+ super().__init__()
30
+ self.learnt = learnt
31
+ self.channel_wise = channel_wise
32
+ groups = 1
33
+ if learnt:
34
+ assert dimension is not None, "Dimension required for learnt convolutions."
35
+ in_channels = dimension
36
+ out_channels = dimension
37
+ if channel_wise:
38
+ groups = dimension
39
+ else:
40
+ in_channels = 1
41
+ out_channels = 1
42
+
43
+ self.conv = StreamingConv1d(
44
+ in_channels,
45
+ out_channels,
46
+ kernel_size=2 * stride,
47
+ stride=stride,
48
+ causal=causal,
49
+ groups=groups,
50
+ bias=False,
51
+ pad_mode="replicate",
52
+ )
53
+ if not learnt:
54
+ actual_conv = self.conv.conv.conv
55
+ actual_conv.weight.requires_grad_(False)
56
+ actual_conv.weight.data.fill_(1.0 / (2 * stride))
57
+
58
+ def forward(self, x: torch.Tensor):
59
+ batch_size = len(x)
60
+ if not self.learnt:
61
+ x = rearrange(x, "b c t -> (b c) () t")
62
+ y = self.conv(x)
63
+ if not self.learnt:
64
+ y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
65
+ return y
66
+
67
+
68
+ class ConvTrUpsample1d(nn.Module):
69
+ """
70
+ Upsample by some integer amount `stride` using transposed convolutions.
71
+ """
72
+
73
+ def __init__(
74
+ self,
75
+ stride: int,
76
+ dimension: tp.Optional[int] = None,
77
+ causal: bool = False,
78
+ learnt: bool = False,
79
+ channel_wise: bool = False,
80
+ ):
81
+ super().__init__()
82
+ self.learnt = learnt
83
+ self.channel_wise = channel_wise
84
+ groups = 1
85
+ if learnt:
86
+ assert dimension is not None, "Dimension required for learnt convolutions."
87
+ in_channels = dimension
88
+ out_channels = dimension
89
+ if channel_wise:
90
+ groups = dimension
91
+ else:
92
+ in_channels = 1
93
+ out_channels = 1
94
+
95
+ self.convtr = StreamingConvTranspose1d(
96
+ in_channels,
97
+ out_channels,
98
+ kernel_size=2 * stride,
99
+ stride=stride,
100
+ causal=causal,
101
+ groups=groups,
102
+ bias=False,
103
+ )
104
+ if not learnt:
105
+ actual_convtr = self.convtr.convtr.convtr
106
+ actual_convtr.weight.requires_grad_(False)
107
+ actual_convtr.weight.data.fill_(1.0)
108
+
109
+ def forward(self, x: torch.Tensor):
110
+ batch_size = len(x)
111
+ if not self.learnt:
112
+ x = rearrange(x, "b c t -> (b c) () t")
113
+ y = self.convtr(x)
114
+ if not self.learnt:
115
+ x_for_normalization = torch.ones_like(x[:1])
116
+ normalization = self.convtr(x_for_normalization)
117
+ y = y / normalization
118
+ y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
119
+ return y
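Sketch of the non-learnt (fixed box filter) resamplers above, downsampling by 2 and upsampling back (assumes the `moshi` package is importable):

import torch
from moshi.modules.resample import ConvDownsample1d, ConvTrUpsample1d

down = ConvDownsample1d(stride=2, causal=True)
up = ConvTrUpsample1d(stride=2, causal=True)

x = torch.randn(3, 512, 16)   # [B, C, T]
y = down(x)                   # [3, 512, 8]
z = up(y)                     # [3, 512, 16]
print(y.shape, z.shape)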
moshi/modules/rope.py ADDED
@@ -0,0 +1,90 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ from torch import nn
6
+ import math
7
+ import torch
8
+ from ..utils.compile import torch_compile_lazy
9
+
10
+
11
+ @torch_compile_lazy
12
+ def apply_rope(
13
+ q: torch.Tensor,
14
+ k: torch.Tensor,
15
+ offset: torch.Tensor,
16
+ max_period: float = 10_000,
17
+ time_before_heads: bool = False,
18
+ ):
19
+ """
20
+ Args:
21
+ q (torch.Tensor): queries, shape `[B, T, H, D]`.
22
+ k (torch.Tensor): keys, shape `[B, T, H, D]`.
23
+ offset (torch.Tensor): current offset, e.g. when streaming.
24
+ max_period (float): maximum period for the cos and sin.
25
+ time_before_heads (bool): if True, expects shapes `[B, T, H, D]`, else `[B, H, T, D]`.
26
+ """
27
+
28
+ if time_before_heads:
29
+ B, T, H, D = q.shape
30
+ else:
31
+ B, H, T, D = q.shape
32
+ assert k.shape == q.shape
33
+ assert D > 0
34
+ assert D % 2 == 0
35
+ assert max_period > 0
36
+
37
+ ds = torch.arange(D // 2, device=q.device, dtype=torch.float32)
38
+ freqs = torch.exp(ds * (-math.log(max_period) * 2 / D))
39
+ ts = offset.float() + torch.arange(T, device=q.device, dtype=torch.float32)
40
+ if time_before_heads:
41
+ ts = ts.view(-1, 1, 1)
42
+ else:
43
+ ts = ts.view(1, -1, 1)
44
+
45
+ dims = q.shape[:-1]
46
+ q = q.view(*dims, D // 2, 2)
47
+ k = k.view(*dims, D // 2, 2)
48
+
49
+ # convention is `r` suffix is real part, `i` is imaginary.
50
+ qr = q[..., 0].float()
51
+ qi = q[..., 1].float()
52
+
53
+ kr = k[..., 0].float()
54
+ ki = k[..., 1].float()
55
+
56
+ rotr = torch.cos(freqs * ts)
57
+ roti = torch.sin(freqs * ts)
58
+ qor = qr * rotr - qi * roti
59
+ qoi = qr * roti + qi * rotr
60
+
61
+ kor = kr * rotr - ki * roti
62
+ koi = kr * roti + ki * rotr
63
+
64
+ dtype = q.dtype
65
+ qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
66
+ ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)
67
+
68
+ return qo.view(*dims, D), ko.view(*dims, D)
69
+
70
+
71
+ class RotaryEmbedding(nn.Module):
72
+ """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).
73
+
74
+ Args:
75
+ max_period (float): Maximum period of the rotation frequencies.
76
+ """
77
+
78
+ def __init__(self, max_period: float = 10000.0):
79
+ super().__init__()
80
+ self.max_period = max_period
81
+
82
+ def forward(
83
+ self,
84
+ q: torch.Tensor,
85
+ k: torch.Tensor,
86
+ offset: torch.Tensor,
87
+ time_before_heads: bool = False,
88
+ ):
89
+ """Apply rope rotation to query or key tensor."""
90
+ return apply_rope(q, k, offset, self.max_period, time_before_heads)
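Minimal usage sketch for RotaryEmbedding, following the shapes documented in apply_rope (assumes the `moshi` package is importable):

import torch
from moshi.modules.rope import RotaryEmbedding

B, H, T, D = 2, 8, 4, 64
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
offset = torch.zeros(1, dtype=torch.long)   # streaming offset; 0 for the first chunk

rope = RotaryEmbedding(max_period=10_000)
q_rot, k_rot = rope(q, k, offset)           # same shapes as the inputs
print(q_rot.shape, k_rot.shape)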
moshi/modules/seanet.py ADDED
@@ -0,0 +1,395 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+ import typing as tp
12
+
13
+ import numpy as np
14
+ import torch.nn as nn
15
+
16
+ from .conv import StreamingConv1d, StreamingConvTranspose1d
17
+ from .streaming import StreamingContainer, StreamingAdd
18
+ from ..utils.compile import torch_compile_lazy
19
+
20
+
21
+ class SEANetResnetBlock(StreamingContainer):
22
+ """Residual block from SEANet model.
23
+
24
+ Args:
25
+ dim (int): Dimension of the input/output.
26
+ kernel_sizes (list): List of kernel sizes for the convolutions.
27
+ dilations (list): List of dilations for the convolutions.
28
+ activation (str): Activation function.
29
+ activation_params (dict): Parameters to provide to the activation function.
30
+ norm (str): Normalization method.
31
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
32
+ causal (bool): Whether to use fully causal convolution.
33
+ pad_mode (str): Padding mode for the convolutions.
34
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
35
+ true_skip (bool): Whether to use true skip connection or a simple
36
+ (streamable) convolution as the skip connection.
37
+ """
38
+
39
+ def __init__(
40
+ self,
41
+ dim: int,
42
+ kernel_sizes: tp.List[int] = [3, 1],
43
+ dilations: tp.List[int] = [1, 1],
44
+ activation: str = "ELU",
45
+ activation_params: dict = {"alpha": 1.0},
46
+ norm: str = "none",
47
+ norm_params: tp.Dict[str, tp.Any] = {},
48
+ causal: bool = False,
49
+ pad_mode: str = "reflect",
50
+ compress: int = 2,
51
+ true_skip: bool = True,
52
+ ):
53
+ super().__init__()
54
+ assert len(kernel_sizes) == len(
55
+ dilations
56
+ ), "Number of kernel sizes should match number of dilations"
57
+ act = getattr(nn, activation)
58
+ hidden = dim // compress
59
+ block = []
60
+ for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
61
+ in_chs = dim if i == 0 else hidden
62
+ out_chs = dim if i == len(kernel_sizes) - 1 else hidden
63
+ block += [
64
+ act(**activation_params),
65
+ StreamingConv1d(
66
+ in_chs,
67
+ out_chs,
68
+ kernel_size=kernel_size,
69
+ dilation=dilation,
70
+ norm=norm,
71
+ norm_kwargs=norm_params,
72
+ causal=causal,
73
+ pad_mode=pad_mode,
74
+ ),
75
+ ]
76
+ self.block = nn.Sequential(*block)
77
+ self.add = StreamingAdd()
78
+ self.shortcut: nn.Module
79
+ if true_skip:
80
+ self.shortcut = nn.Identity()
81
+ else:
82
+ self.shortcut = StreamingConv1d(
83
+ dim,
84
+ dim,
85
+ kernel_size=1,
86
+ norm=norm,
87
+ norm_kwargs=norm_params,
88
+ causal=causal,
89
+ pad_mode=pad_mode,
90
+ )
91
+
92
+ def forward(self, x):
93
+ u, v = self.shortcut(x), self.block(x)
94
+ return self.add(u, v)
95
+
96
+
97
+ class SEANetEncoder(StreamingContainer):
98
+ """SEANet encoder.
99
+
100
+ Args:
101
+ channels (int): Audio channels.
102
+ dimension (int): Intermediate representation dimension.
103
+ n_filters (int): Base width for the model.
104
+ n_residual_layers (int): nb of residual layers.
105
+ ratios (Sequence[int]): kernel size and stride ratios. The encoder downsamples rather than
106
+ upsamples, hence it applies these ratios in the reverse order to the one given here, which
107
+ must match the decoder order. We use the decoder order as some models may only employ the decoder.
108
+ activation (str): Activation function.
109
+ activation_params (dict): Parameters to provide to the activation function.
110
+ norm (str): Normalization method.
111
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
112
+ kernel_size (int): Kernel size for the initial convolution.
113
+ last_kernel_size (int): Kernel size for the last convolution.
114
+ residual_kernel_size (int): Kernel size for the residual layers.
115
+ dilation_base (int): How much to increase the dilation with each layer.
116
+ causal (bool): Whether to use fully causal convolution.
117
+ pad_mode (str): Padding mode for the convolutions.
118
+ true_skip (bool): Whether to use true skip connection or a simple
119
+ (streamable) convolution as the skip connection in the residual network blocks.
120
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
121
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
122
+ For the encoder, it corresponds to the N first blocks.
123
+ mask_fn (nn.Module): Optional mask function to apply after convolution layers.
124
+ mask_position (int): Position of the mask function, with mask_position == 0 for the first convolution layer,
125
+ mask_position == 1 for the first conv block, etc.
126
+ """
127
+
128
+ def __init__(
129
+ self,
130
+ channels: int = 1,
131
+ dimension: int = 128,
132
+ n_filters: int = 32,
133
+ n_residual_layers: int = 3,
134
+ ratios: tp.List[int] = [8, 5, 4, 2],
135
+ activation: str = "ELU",
136
+ activation_params: dict = {"alpha": 1.0},
137
+ norm: str = "none",
138
+ norm_params: tp.Dict[str, tp.Any] = {},
139
+ kernel_size: int = 7,
140
+ last_kernel_size: int = 7,
141
+ residual_kernel_size: int = 3,
142
+ dilation_base: int = 2,
143
+ causal: bool = False,
144
+ pad_mode: str = "reflect",
145
+ true_skip: bool = True,
146
+ compress: int = 2,
147
+ disable_norm_outer_blocks: int = 0,
148
+ mask_fn: tp.Optional[nn.Module] = None,
149
+ mask_position: tp.Optional[int] = None,
150
+ ):
151
+ super().__init__()
152
+ self.channels = channels
153
+ self.dimension = dimension
154
+ self.n_filters = n_filters
155
+ self.ratios = list(reversed(ratios))
156
+ del ratios
157
+ self.n_residual_layers = n_residual_layers
158
+ self.hop_length = int(np.prod(self.ratios))
159
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
160
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
161
+ assert (
162
+ self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks
163
+ ), (
164
+ "Number of blocks for which to disable norm is invalid. "
165
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
166
+ )
167
+
168
+ act = getattr(nn, activation)
169
+ mult = 1
170
+ model: tp.List[nn.Module] = [
171
+ StreamingConv1d(
172
+ channels,
173
+ mult * n_filters,
174
+ kernel_size,
175
+ norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
176
+ norm_kwargs=norm_params,
177
+ causal=causal,
178
+ pad_mode=pad_mode,
179
+ )
180
+ ]
181
+ if mask_fn is not None and mask_position == 0:
182
+ model += [mask_fn]
183
+ # Downsample to raw audio scale
184
+ for i, ratio in enumerate(self.ratios):
185
+ block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm
186
+ # Add residual layers
187
+ for j in range(n_residual_layers):
188
+ model += [
189
+ SEANetResnetBlock(
190
+ mult * n_filters,
191
+ kernel_sizes=[residual_kernel_size, 1],
192
+ dilations=[dilation_base**j, 1],
193
+ norm=block_norm,
194
+ norm_params=norm_params,
195
+ activation=activation,
196
+ activation_params=activation_params,
197
+ causal=causal,
198
+ pad_mode=pad_mode,
199
+ compress=compress,
200
+ true_skip=true_skip,
201
+ )
202
+ ]
203
+
204
+ # Add downsampling layers
205
+ model += [
206
+ act(**activation_params),
207
+ StreamingConv1d(
208
+ mult * n_filters,
209
+ mult * n_filters * 2,
210
+ kernel_size=ratio * 2,
211
+ stride=ratio,
212
+ norm=block_norm,
213
+ norm_kwargs=norm_params,
214
+ causal=causal,
215
+ pad_mode=pad_mode,
216
+ ),
217
+ ]
218
+ mult *= 2
219
+ if mask_fn is not None and mask_position == i + 1:
220
+ model += [mask_fn]
221
+
222
+ model += [
223
+ act(**activation_params),
224
+ StreamingConv1d(
225
+ mult * n_filters,
226
+ dimension,
227
+ last_kernel_size,
228
+ norm=(
229
+ "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
230
+ ),
231
+ norm_kwargs=norm_params,
232
+ causal=causal,
233
+ pad_mode=pad_mode,
234
+ ),
235
+ ]
236
+
237
+ self.model = nn.Sequential(*model)
238
+
239
+ @torch_compile_lazy
240
+ def forward(self, x):
241
+ return self.model(x)
242
+
243
+
244
+ class SEANetDecoder(StreamingContainer):
245
+ """SEANet decoder.
246
+
247
+ Args:
248
+ channels (int): Audio channels.
249
+ dimension (int): Intermediate representation dimension.
250
+ n_filters (int): Base width for the model.
251
+ n_residual_layers (int): nb of residual layers.
252
+ ratios (Sequence[int]): kernel size and stride ratios.
253
+ activation (str): Activation function.
254
+ activation_params (dict): Parameters to provide to the activation function.
255
+ final_activation (str): Final activation function after all convolutions.
256
+ final_activation_params (dict): Parameters to provide to the activation function.
257
+ norm (str): Normalization method.
258
+ norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
259
+ kernel_size (int): Kernel size for the initial convolution.
260
+ last_kernel_size (int): Kernel size for the last convolution.
261
+ residual_kernel_size (int): Kernel size for the residual layers.
262
+ dilation_base (int): How much to increase the dilation with each layer.
263
+ causal (bool): Whether to use fully causal convolution.
264
+ pad_mode (str): Padding mode for the convolutions.
265
+ true_skip (bool): Whether to use true skip connection or a simple
266
+ (streamable) convolution as the skip connection in the residual network blocks.
267
+ compress (int): Reduced dimensionality in residual branches (from Demucs v3).
268
+ disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
269
+ For the decoder, it corresponds to the N last blocks.
270
+ trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
271
+ If equal to 1.0, it means that all the trimming is done at the right.
272
+ """
273
+
274
+ def __init__(
275
+ self,
276
+ channels: int = 1,
277
+ dimension: int = 128,
278
+ n_filters: int = 32,
279
+ n_residual_layers: int = 3,
280
+ ratios: tp.List[int] = [8, 5, 4, 2],
281
+ activation: str = "ELU",
282
+ activation_params: dict = {"alpha": 1.0},
283
+ final_activation: tp.Optional[str] = None,
284
+ final_activation_params: tp.Optional[dict] = None,
285
+ norm: str = "none",
286
+ norm_params: tp.Dict[str, tp.Any] = {},
287
+ kernel_size: int = 7,
288
+ last_kernel_size: int = 7,
289
+ residual_kernel_size: int = 3,
290
+ dilation_base: int = 2,
291
+ causal: bool = False,
292
+ pad_mode: str = "reflect",
293
+ true_skip: bool = True,
294
+ compress: int = 2,
295
+ disable_norm_outer_blocks: int = 0,
296
+ trim_right_ratio: float = 1.0,
297
+ ):
298
+ super().__init__()
299
+ self.dimension = dimension
300
+ self.channels = channels
301
+ self.n_filters = n_filters
302
+ self.ratios = ratios
303
+ del ratios
304
+ self.n_residual_layers = n_residual_layers
305
+ self.hop_length = int(np.prod(self.ratios))
306
+ self.n_blocks = len(self.ratios) + 2 # first and last conv + residual blocks
307
+ self.disable_norm_outer_blocks = disable_norm_outer_blocks
308
+ assert (
309
+ self.disable_norm_outer_blocks >= 0 and self.disable_norm_outer_blocks <= self.n_blocks
310
+ ), (
311
+ "Number of blocks for which to disable norm is invalid. "
312
+ "It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
313
+ )
314
+
315
+ act = getattr(nn, activation)
316
+ mult = int(2 ** len(self.ratios))
317
+ model: tp.List[nn.Module] = [
318
+ StreamingConv1d(
319
+ dimension,
320
+ mult * n_filters,
321
+ kernel_size,
322
+ norm=(
323
+ "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
324
+ ),
325
+ norm_kwargs=norm_params,
326
+ causal=causal,
327
+ pad_mode=pad_mode,
328
+ )
329
+ ]
330
+
331
+ # Upsample to raw audio scale
332
+ for i, ratio in enumerate(self.ratios):
333
+ block_norm = (
334
+ "none"
335
+ if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1)
336
+ else norm
337
+ )
338
+ # Add upsampling layers
339
+ model += [
340
+ act(**activation_params),
341
+ StreamingConvTranspose1d(
342
+ mult * n_filters,
343
+ mult * n_filters // 2,
344
+ kernel_size=ratio * 2,
345
+ stride=ratio,
346
+ norm=block_norm,
347
+ norm_kwargs=norm_params,
348
+ causal=causal,
349
+ trim_right_ratio=trim_right_ratio,
350
+ ),
351
+ ]
352
+ # Add residual layers
353
+ for j in range(n_residual_layers):
354
+ model += [
355
+ SEANetResnetBlock(
356
+ mult * n_filters // 2,
357
+ kernel_sizes=[residual_kernel_size, 1],
358
+ dilations=[dilation_base**j, 1],
359
+ activation=activation,
360
+ activation_params=activation_params,
361
+ norm=block_norm,
362
+ norm_params=norm_params,
363
+ causal=causal,
364
+ pad_mode=pad_mode,
365
+ compress=compress,
366
+ true_skip=true_skip,
367
+ )
368
+ ]
369
+
370
+ mult //= 2
371
+
372
+ # Add final layers
373
+ model += [
374
+ act(**activation_params),
375
+ StreamingConv1d(
376
+ n_filters,
377
+ channels,
378
+ last_kernel_size,
379
+ norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
380
+ norm_kwargs=norm_params,
381
+ causal=causal,
382
+ pad_mode=pad_mode,
383
+ ),
384
+ ]
385
+ # Add optional final activation to decoder (eg. tanh)
386
+ if final_activation is not None:
387
+ final_act = getattr(nn, final_activation)
388
+ final_activation_params = final_activation_params or {}
389
+ model += [final_act(**final_activation_params)]
390
+ self.model = nn.Sequential(*model)
391
+
392
+ @torch_compile_lazy
393
+ def forward(self, z):
394
+ y = self.model(z)
395
+ return y
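For reference, the encoder geometry implied by the `_seanet_kwargs` in loaders.py (a sketch; plain Python):

import numpy as np

ratios = [8, 6, 5, 4]                  # _seanet_kwargs["ratios"]
hop_length = int(np.prod(ratios))      # samples of 24 kHz audio per latent frame
print(hop_length, 24000 / hop_length)  # 960, i.e. 25 latent frames per second,
                                       # which MimiModel is configured to resample to FRAME_RATE = 12.5 Hz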
moshi/modules/streaming.py ADDED
@@ -0,0 +1,363 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+ """
12
+ Streaming module API that should be implemented by all streaming components.
13
+ """
14
+
15
+ import abc
16
+ from contextlib import contextmanager
17
+ from dataclasses import dataclass
18
+ import itertools
19
+ import math
20
+ import typing as tp
21
+ from torch import nn
22
+ import torch
23
+
24
+
25
+ class Resetable(tp.Protocol):
26
+ def reset(self) -> None:
27
+ pass
28
+
29
+
30
+ State = tp.TypeVar("State", bound=Resetable)
31
+
32
+
33
+ class StreamingModule(abc.ABC, nn.Module, tp.Generic[State]):
34
+ """Common API for streaming components.
35
+
36
+ Each streaming component has a streaming state, which is just a dict[str, Tensor].
37
+ By convention, the first dim of each tensor must be the batch size.
38
+ Don't use dots in the key names, as this would clash with submodules
39
+ (like in state_dict).
40
+
41
+ If `self._is_streaming` is True, the component should use and remember
42
+ the proper state inside `self._streaming_state`.
43
+
44
+ To set a streaming component in streaming state, use
45
+
46
+ with module.streaming():
47
+ ...
48
+
49
+ This will automatically reset the streaming state when exiting the context manager.
50
+ This also automatically propagates to all streaming children module.
51
+
52
+ Some modules might also implement the `StreamingModule.flush` method, although
53
+ this one is trickier, as all parent modules must be StreamingModule and implement
54
+ it as well for it to work properly. See `StreamingSequential` below.
55
+ """
56
+
57
+ def __init__(self) -> None:
58
+ super().__init__()
59
+ self._streaming_state: State | None = None
60
+ self._streaming_propagate: bool = True
61
+
62
+ @property
63
+ def is_streaming(self):
64
+ return self._streaming_state is not None
65
+
66
+ def set_streaming_propagate(self, streaming_propagate: bool):
67
+ self._streaming_propagate = streaming_propagate
68
+
69
+ def _apply_named_streaming(self, fn: tp.Any):
70
+ def _handle_module(prefix: str, module: nn.Module, recurse: bool = True):
71
+ propagate = True
72
+ if isinstance(module, StreamingModule):
73
+ if module._streaming_propagate:
74
+ fn(prefix, module)
75
+ else:
76
+ propagate = False
77
+ if not recurse:
78
+ return
79
+ if propagate:
80
+ for name, child in module.named_children():
81
+ _handle_module(prefix + "." + name, child)
82
+
83
+ _handle_module("", self, recurse=False)
84
+ for name, child in self.named_children():
85
+ _handle_module(name, child)
86
+
87
+ def _start_streaming(self, batch_size: int):
88
+ def _start_streaming(name: str, module: StreamingModule):
89
+ module._streaming_state = module._init_streaming_state(batch_size)
90
+
91
+ self._apply_named_streaming(_start_streaming)
92
+
93
+ def _stop_streaming(self):
94
+ def _stop_streaming(name: str, module: StreamingModule):
95
+ module._streaming_state = None
96
+
97
+ self._apply_named_streaming(_stop_streaming)
98
+
99
+ @abc.abstractmethod
100
+ def _init_streaming_state(self, batch_size: int) -> State: ...
101
+
102
+ def streaming_forever(self, batch_size: int):
103
+ self._start_streaming(batch_size)
104
+
105
+ @contextmanager
106
+ def streaming(self, batch_size: int):
107
+ """Context manager to enter streaming mode. Reset streaming state on exit."""
108
+
109
+ self._start_streaming(batch_size)
110
+ try:
111
+ yield
112
+ finally:
113
+ self._stop_streaming()
114
+
115
+ def reset_streaming(self):
116
+ """Reset the streaming state."""
117
+
118
+ def _reset(name: str, module: StreamingModule):
119
+ state = module._streaming_state
120
+ if state is None:
121
+ raise ValueError(
122
+ f"Trying to reset streaming, but {name} wasn't streaming."
123
+ )
124
+ state.reset()
125
+
126
+ self._apply_named_streaming(_reset)
127
+
128
+ def get_streaming_state(self) -> dict[str, tp.Any]:
129
+ """Return the complete streaming state, including that of sub-modules."""
130
+ state: dict[str, tp.Any] = {}
131
+
132
+ def _add(name: str, module: StreamingModule):
133
+ state[name] = module._streaming_state
134
+
135
+ self._apply_named_streaming(_add)
136
+ return state
137
+
138
+ def set_streaming_state(self, state: dict[str, tp.Any]):
139
+ """Set the streaming state, including that of sub-modules."""
140
+ state = dict(state)
141
+
142
+ def _set(name: str, module: StreamingModule):
143
+ if name in state:
144
+ module._streaming_state = state[name]
145
+ state.pop(name)
146
+ else:
147
+ raise RuntimeError(f"Expected to find a streaming state for {name}.")
148
+
149
+ self._apply_named_streaming(_set)
150
+ if state:
151
+ raise RuntimeError(f"Some states were not consumed: {list(state.keys())}")
152
+
153
+
154
+ @dataclass
155
+ class _NullState:
156
+ pass
157
+
158
+ def reset(self) -> None:
159
+ pass
160
+
161
+
162
+ class StreamingContainer(StreamingModule[_NullState]):
163
+ def _init_streaming_state(self, batch_size: int) -> _NullState:
164
+ return _NullState()
165
+
166
+
167
+ @dataclass
168
+ class _StreamingAddState:
169
+ previous_x: torch.Tensor | None = None
170
+ previous_y: torch.Tensor | None = None
171
+
172
+ def reset(self):
173
+ self.previous_x = None
174
+ self.previous_y = None
175
+
176
+
177
+ class StreamingAdd(StreamingModule[_StreamingAddState]):
178
+ def _init_streaming_state(self, batch_size: int) -> _StreamingAddState:
179
+ return _StreamingAddState()
180
+
181
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
182
+ if self._streaming_state is None:
183
+ return x + y
184
+ else:
185
+ prev_x = self._streaming_state.previous_x
186
+ prev_y = self._streaming_state.previous_y
187
+ if prev_x is not None:
188
+ x = torch.cat([prev_x, x], dim=-1)
189
+ if prev_y is not None:
190
+ y = torch.cat([prev_y, y], dim=-1)
191
+ m_l = min(x.shape[-1], y.shape[-1])
192
+ self._streaming_state.previous_x = x[..., m_l:]
193
+ self._streaming_state.previous_y = y[..., m_l:]
194
+ return x[..., :m_l] + y[..., :m_l]
195
+
196
+
197
+ @dataclass
198
+ class _StreamingConvState:
199
+ previous: torch.Tensor | None = None
200
+
201
+ def reset(self):
202
+ self.previous = None
203
+
204
+
205
+ class RawStreamingConv1d(nn.Conv1d, StreamingModule[_StreamingConvState]):
206
+ def __init__(self, *args, **kwargs):
207
+ super().__init__(*args, **kwargs)
208
+ assert self.padding[0] == 0, "Padding should be handled outside."
209
+ assert (
210
+ self.stride[0] <= self.kernel_size[0]
211
+ ), "stride must be less than kernel_size."
212
+
213
+ def _init_streaming_state(self, batch_size: int) -> _StreamingConvState:
214
+ return _StreamingConvState()
215
+
216
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
217
+ stride = self.stride[0]
218
+ # Effective kernel size accounting for dilation.
219
+ kernel = (self.kernel_size[0] - 1) * self.dilation[0] + 1
220
+ if self._streaming_state is None:
221
+ return super().forward(input)
222
+ else:
223
+ # Due to the potential overlap, we might have some cache of the previous time steps.
224
+ previous = self._streaming_state.previous
225
+ if previous is not None:
226
+ input = torch.cat([previous, input], dim=-1)
227
+ B, C, T = input.shape
228
+ # We now compute the number of full convolution frames, i.e. the frames
229
+ # that are ready to be computed.
230
+ num_frames = max(0, int(math.floor((T - kernel) / stride) + 1))
231
+ offset = num_frames * stride
232
+ # We will compute `num_frames` outputs, and we are advancing by `stride`
233
+ # for each of the frame, so we know the data before `stride * num_frames`
234
+ # will never be used again.
235
+ self._streaming_state.previous = input[..., offset:]
236
+ if num_frames > 0:
237
+ input_length = (num_frames - 1) * stride + kernel
238
+ out = super().forward(input[..., :input_length])
239
+ else:
240
+ # Not enough data at this point to output any new frames.
241
+ out = torch.empty(
242
+ B, self.out_channels, 0, device=input.device, dtype=input.dtype
243
+ )
244
+ return out
245
+
246
+
247
+ @dataclass
248
+ class _StreamingConvTrState:
249
+ partial: torch.Tensor | None = None
250
+
251
+ def reset(self):
252
+ self.partial = None
253
+
254
+
255
+ class RawStreamingConvTranspose1d(
256
+ nn.ConvTranspose1d, StreamingModule[_StreamingConvTrState]
257
+ ):
258
+ def __init__(self, *args, **kwargs):
259
+ super().__init__(*args, **kwargs)
260
+ assert self.padding[0] == 0, "Padding should be handled outside."
261
+ assert self.dilation[0] == 1, "No dilation for now"
262
+ assert (
263
+ self.stride[0] <= self.kernel_size[0]
264
+ ), "stride must be less than kernel_size."
265
+ assert self.output_padding[0] == 0, "Output padding not supported."
266
+
267
+ def _init_streaming_state(self, batch_size: int) -> _StreamingConvTrState:
268
+ return _StreamingConvTrState()
269
+
270
+ def forward(self, x: torch.Tensor) -> torch.Tensor: # type: ignore
271
+ B, C, T = x.shape
272
+ stride = self.stride[0]
273
+ kernel = self.kernel_size[0]
274
+ if self._streaming_state is None:
275
+ return super().forward(x)
276
+ else:
277
+ if T == 0:
278
+ return torch.empty(
279
+ B, self.out_channels, 0, device=x.device, dtype=x.dtype
280
+ )
281
+ out = super().forward(x)
282
+ OT = out.shape[-1]
283
+ partial = self._streaming_state.partial
284
+ if partial is not None:
285
+ # Due to the potential overlap, the rightmost output of the conv transpose is not
286
+ # ready to be output, as it will receive contributions from the next input frames.
287
+ # Here we recover those `partial` output frames. We know that the first time step
288
+ # of the `partial` tensor corresponds to the first time step of `out` as anything
289
+ # coming before the first time step of `out` would have been already flushed.
290
+ PT = partial.shape[-1]
291
+ if self.bias is not None:
292
+ out[..., :PT] += partial - self.bias[:, None]
293
+ else:
294
+ out[..., :PT] += partial
295
+ # The input is T, the output is S * (T - 1) + K.
296
+ # The offset of the left of the next frame will be S * T
297
+ # so everything between 0 and S * T is ready to be output, and we need
298
+ # to keep in the internal state everything beyond that, i.e. S (T - 1) + K - S T = K - S
299
+ invalid_steps = kernel - stride
300
+ partial = out[..., OT - invalid_steps :]
301
+ out = out[..., : OT - invalid_steps]
302
+ self._streaming_state.partial = partial
303
+ return out
304
+
305
+
306
+ def test():
307
+ torch.manual_seed(1234)
308
+ device = "cpu"
309
+ if torch.cuda.is_available():
310
+ # Avoid the cuda optimizations that would take place on single precision
311
+ # floats for convolutions.
312
+ torch.backends.cudnn.enabled = True
313
+ torch.backends.cudnn.benchmark = False
314
+ torch.backends.cudnn.deterministic = True
315
+ torch.backends.cuda.matmul.allow_tf32 = False
316
+ torch.backends.cudnn.allow_tf32 = False
317
+ device = "cuda:0"
318
+
319
+ kernel_sizes = [1, 3, 4, 8, 15, 16]
320
+ strides = [1, 2, 3, 4, 5, 6, 7, 8, 9]
321
+ chin = 6
322
+ chout = 12
323
+
324
+ for kernel, stride in itertools.product(kernel_sizes, strides):
325
+ if stride > kernel:
326
+ continue
327
+ conv = RawStreamingConv1d(chin, chout, kernel, stride).to(device)
328
+ convtr = RawStreamingConvTranspose1d(chout, chin, kernel, stride).to(device)
329
+
330
+ for length in [4, 8, 32, 54, 65, 128, 1043]:
331
+ print(f"ksize {kernel} strides {stride} len {length}")
332
+ if length < kernel:
333
+ continue
334
+ batch_size = 3
335
+ x = torch.randn(batch_size, chin, length).to(device)
336
+ y = conv(x)
337
+ z = convtr(y)
338
+ for chunk_size in [1, 3, 5, 8]:
339
+ ys = []
340
+ zs = []
341
+ with conv.streaming(batch_size), convtr.streaming(batch_size):
342
+ for offset in range(0, length, chunk_size):
343
+ chunk = x[..., offset : offset + chunk_size]
344
+ ys.append(conv(chunk))
345
+ zs.append(convtr(ys[-1]))
346
+ y_stream = torch.cat(ys, dim=-1)
347
+ z_stream = torch.cat(zs, dim=-1)
348
+ y = y[..., : y_stream.shape[-1]]
349
+ z = z[..., : z_stream.shape[-1]]
350
+ assert y.shape == y_stream.shape, (y.shape, y_stream.shape)
351
+ delta = (y_stream - y).norm() / y.norm()
352
+ assert delta <= 1e-6, delta
353
+ num_frames = int((length - kernel) / stride) + 1
354
+ assert num_frames == y_stream.shape[-1]
355
+
356
+ assert z.shape == z_stream.shape, (z.shape, z_stream.shape)
357
+ delta = (z_stream - z).norm() / z.norm()
358
+ assert delta <= 1e-6, (delta, (z_stream - z).abs().mean(dim=(0, 1)))
359
+
360
+
361
+ if __name__ == "__main__":
362
+ with torch.no_grad():
363
+ test()
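Sketch of how StreamingAdd keeps its two operands aligned when they arrive in chunks of different sizes (assumes the `moshi` package is importable):

import torch
from moshi.modules.streaming import StreamingAdd

add = StreamingAdd()
x = torch.randn(1, 2, 8)
y = torch.randn(1, 2, 8)
ref = x + y

with add.streaming(batch_size=1):
    out = torch.cat([add(x[..., :5], y[..., :3]),   # y lags behind x by two steps
                     add(x[..., 5:], y[..., 3:])], dim=-1)
print(torch.allclose(out, ref))  # True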
moshi/modules/transformer.py ADDED
@@ -0,0 +1,750 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ """
6
+ Transformer model, with streaming support, + CUDA Graphable.
7
+ Optimized for inference.
8
+
9
+ See `StreamingTransformer` for more information.
10
+ """
11
+
12
+ from contextlib import ExitStack
13
+ from dataclasses import dataclass
14
+ import typing as tp
15
+
16
+ from einops import rearrange
17
+ import torch
18
+ import torch.nn as nn
19
+ from torch.nn import functional as F
20
+
21
+ from ..utils.compile import no_compile
22
+ from .gating import make_gating
23
+ from .rope import RotaryEmbedding
24
+ from .streaming import StreamingModule, StreamingContainer
25
+
26
+
27
+ class LayerNormF32(nn.LayerNorm):
28
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
29
+ x_f32 = input.float()
30
+ out_f32 = super().forward(x_f32)
31
+ return out_f32.to(input.dtype)
32
+
33
+
34
+ def _rms_norm(
35
+ x: torch.Tensor,
36
+ alpha: torch.Tensor,
37
+ dtype: tp.Optional[torch.dtype],
38
+ eps: float,
39
+ ):
40
+ assert x.dim() == 3, f"RMSNorm expects 3D inputs but got {x.shape}"
41
+ x_dtype = x.dtype
42
+ if dtype is not None:
43
+ x = x.to(dtype)
44
+ var = eps + torch.mean(x**2, dim=2, keepdim=True)
45
+ y = (x * (alpha.to(var) * torch.rsqrt(var))).to(x_dtype)
46
+ return y
47
+
48
+
49
+ class RMSNorm(nn.Module):
50
+ def __init__(
51
+ self,
52
+ dim: int,
53
+ eps: float = 1e-5,
54
+ dtype: tp.Optional[torch.dtype] = None,
55
+ device=None,
56
+ ):
57
+ super().__init__()
58
+ self.eps = eps
59
+ self.dtype = dtype
60
+ self.alpha = nn.Parameter(
61
+ torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype)
62
+ )
63
+
64
+ def forward(self, x: torch.Tensor):
65
+ return _rms_norm(x, self.alpha, self.dtype, self.eps)
66
+
67
+
68
+ class LayerScale(nn.Module):
69
+ """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
70
+ This rescales diagonally the residual outputs close to 0, with a learnt scale.
71
+
72
+ Args:
73
+ channels (int): Number of channels.
74
+ init (float): Initial scale.
75
+ channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
76
+ device (torch.device or str, optional): Device on which to initialize the module.
77
+ dtype (torch.dtype, optional): dtype to use to initialize the module.
78
+ """
79
+
80
+ def __init__(
81
+ self,
82
+ channels: int,
83
+ init: float = 1e-4,
84
+ channel_last: bool = True,
85
+ device=None,
86
+ dtype=None,
87
+ ):
88
+ super().__init__()
89
+ self.channel_last = channel_last
90
+ self.scale = nn.Parameter(
91
+ torch.full(
92
+ (channels,), init, requires_grad=True, device=device, dtype=dtype
93
+ )
94
+ )
95
+
96
+ def forward(self, x: torch.Tensor):
97
+ if self.channel_last:
98
+ return self.scale * x
99
+ else:
100
+ return self.scale[:, None] * x
101
+
102
+
103
+ def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
104
+ """Create normalization module for transformer encoder layer.
105
+
106
+ Args:
107
+ norm_type (str): Normalization method.
108
+ dim (int): Dimension of the normalized layer.
109
+ **kwargs (dict): Additional parameters for normalization layer.
110
+ Returns:
111
+ nn.Module: Normalization module.
112
+ """
113
+ if norm_type == "layer_norm":
114
+ return nn.LayerNorm(dim, eps=1e-5, **kwargs)
115
+ elif norm_type == "layer_norm_f32":
116
+ kwargs.pop("dtype", None)
117
+ return LayerNormF32(dim, eps=1e-8, **kwargs)
118
+ elif norm_type in {"rms_norm"}:
119
+ return RMSNorm(dim, eps=1e-5, **kwargs)
120
+ elif norm_type in {"rms_norm_f32"}:
121
+ kwargs.pop("dtype", None)
122
+ return RMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs)
123
+ else:
124
+ raise ValueError(f"Unknown norm type: {norm_type}")
125
+
126
+
127
+ def create_sin_embedding(
128
+ positions: torch.Tensor,
129
+ dim: int,
130
+ max_period: float = 10000,
131
+ dtype: torch.dtype = torch.float32,
132
+ ) -> torch.Tensor:
133
+ """Create sinusoidal positional embedding, with shape `[B, T, C]`.
134
+
135
+ Args:
136
+ positions (torch.Tensor): LongTensor of positions.
137
+ dim (int): Dimension of the embedding.
138
+ max_period (float): Maximum period of the cosine/sine functions.
139
+ dtype (torch.dtype or str): dtype to use to generate the embedding.
140
+ Returns:
141
+ torch.Tensor: Sinusoidal positional embedding.
142
+ """
143
+ # We aim for BTC format
144
+ assert dim % 2 == 0
145
+ half_dim = dim // 2
146
+ positions = positions.to(dtype)
147
+ adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
148
+ max_period_tensor = torch.full(
149
+ [], max_period, device=positions.device, dtype=dtype
150
+ ) # avoid sync point
151
+ phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
152
+ return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
153
+
154
+
155
+ def multi_linear(
156
+ num_linear: int,
157
+ weight: torch.Tensor,
158
+ x: torch.Tensor,
159
+ offset: int,
160
+ ):
161
+ """Utility to apply a multi linear layer to the given input. A multi linear layer
162
+ applies a different set of weights for each time step.
163
+
164
+ Args:
165
+ num_linear (int): Number of possible time steps and so number of linears.
166
+ weight (torch.Tensor): Weight tensor, with shape `[num_linear * chout, chin]`.
167
+ x (torch.Tensor): Input tensor, with shape `[B, T, C]`.
168
+ offset (int): offset for the current time step, in particular for decoding, with
169
+ time steps provided one by one.
170
+ """
171
+ B, T, C = x.shape
172
+ ys = []
173
+ chout, chin = weight.shape
174
+ weight = weight.view(num_linear, -1, chin)
175
+ for t in range(T):
176
+ y = F.linear(x[:, t], weight[t + offset])
177
+ ys.append(y)
178
+ out = torch.stack(ys, 1)
179
+ return out
180
+
181
+
182
+ def set_attention_context(model: nn.Module, context: tp.Optional[int] = None) -> None:
183
+ """Deactivates or changes the context span (in time steps) in a model.
184
+ Args:
185
+ model (nn.Module): model over which to look for attentions.
186
+ context (int or None): new temporary context value.
187
+
188
+ ..Note:: this is not a context manager but a plain function changing the context forever.
189
+ Initially, it was a context manager, but that led to interesting bugs when using
190
+ activation checkpointing, with the context being inconsistent between the forward
191
+ and backward.
192
+ """
193
+ for module in model.modules():
194
+ if isinstance(module, StreamingMultiheadAttention):
195
+ module.context = context
196
+
197
+
198
+ class KVCacheResult(tp.NamedTuple):
199
+ keys: torch.Tensor
200
+ values: torch.Tensor
201
+ positions: torch.Tensor
202
+
203
+ @staticmethod
204
+ def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult":
205
+ B, H, T, D = keys.shape
206
+ assert tuple(values.shape[:-1]) == (B, H, T)
207
+ positions = torch.arange(T, device=keys.device, dtype=torch.long)
208
+ return KVCacheResult(keys, values, positions)
209
+
210
+
211
+ class RingKVCache:
212
+ """Efficient streaming KVCache to be compatible with Cuda Graph.
213
+
214
+ Args:
215
+ batch_size (int): Batch size.
216
+ num_heads (int): Number of heads in the attention.
217
+ dim_per_head (int): Dimension per head.
218
+ device (torch.device): Device on which to initialize the cache.
219
+ dtype (torch.dtype): dtype to use for the cache.
220
+ """
221
+
222
+ def __init__(
223
+ self,
224
+ batch_size: int,
225
+ num_heads: int,
226
+ dim_per_head: int,
227
+ capacity: int,
228
+ device: torch.device = torch.device("cuda"),
229
+ dtype: torch.dtype = torch.bfloat16,
230
+ ):
231
+ self.capacity = capacity
232
+ self.cache = torch.zeros(
233
+ (2, batch_size, num_heads, capacity, dim_per_head),
234
+ device=device,
235
+ dtype=dtype,
236
+ )
237
+ self.end_offset = torch.zeros(1, device=device, dtype=torch.long)
238
+
239
+ def reset(self):
240
+ self.end_offset.zero_()
241
+
242
+ def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:
243
+ assert k.shape[:-1] == v.shape[:-1], (k.shape, v.shape)
244
+ B, H, T, D = k.shape
245
+ indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset
246
+ indexes = indexes % self.capacity
247
+ self.cache[0].index_copy_(2, indexes, k)
248
+ self.cache[1].index_copy_(2, indexes, v)
249
+ self.end_offset.add_(T)
250
+
251
+ keys = self.cache[0]
252
+ values = self.cache[1]
253
+
254
+ indexes = torch.arange(
255
+ self.capacity, device=self.end_offset.device, dtype=torch.long
256
+ )
257
+ invalid = indexes >= self.end_offset
258
+
259
+ end_index = self.end_offset % self.capacity
260
+ delta = indexes - end_index
261
+
262
+ # If last key is for step S, and capacity is C, last key was written at index S % C.
263
+ # then end_offset = S + 1, and end_index = (S + 1) % C.
264
+ # Then for index = (S % C), delta = -1, and the next code gives us:
265
+ # position(index) = (S + 1) - 1 = S, all good.
266
+ # Now the time step at end_offset is actually the oldest in the KVCache, i.e., its
267
+ # position should be (S - self.capacity + 1).
268
+ # The following code gives us:
269
+ # position(index + 1) = S + 1 + 0 - self.capacity.
270
+
271
+ positions = torch.where(
272
+ delta <= 0,
273
+ self.end_offset + delta,
274
+ self.end_offset + delta - self.capacity,
275
+ )
276
+ positions = torch.where(invalid, torch.full_like(positions, -1), positions)
277
+
278
+ return KVCacheResult(keys, values, positions)
279
+
280
+
281
+ @dataclass
282
+ class _MHAState:
283
+ kv_cache: RingKVCache
284
+ offset: torch.Tensor
285
+ offset_cpu: int
286
+
287
+ def reset(self):
288
+ self.kv_cache.reset()
289
+ self.offset.zero_()
290
+ self.offset_cpu = 0
291
+
292
+
293
+ class StreamingMultiheadAttention(StreamingModule[_MHAState]):
294
+ """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.
295
+
296
+ Args:
297
+ embed_dim (int): Dimension to project to.
298
+ num_heads (int): Number of heads.
299
+ causal (bool): Causal mask applied automatically.
300
+ context (int, optional): Number of time steps the attention can access to.
301
+ When causal, can access `context` time steps into the past, and when non causal,
302
+ can access `context // 2` steps in the past, and the same in the future.
303
+ rope (`RotaryEmbedding`, optional): Rope embedding to use.
304
+ weights_per_step (int): use different weights per time step. If non zero, should correspond to the
305
+ number of possible time steps.
306
+ device (torch.device, optional): Device on which to initialize.
307
+ dtype (torch.dtype, optional): dtype to use.
308
+ """
309
+
310
+ _fsdp_final = True
311
+
312
+ def __init__(
313
+ self,
314
+ embed_dim: int,
315
+ num_heads: int,
316
+ causal: bool = False,
317
+ context: tp.Optional[int] = None,
318
+ rope: tp.Optional[RotaryEmbedding] = None,
319
+ weights_per_step: int = 0,
320
+ device=None,
321
+ dtype=None,
322
+ ):
323
+ super().__init__()
324
+ factory_kwargs = {"device": device, "dtype": dtype}
325
+
326
+ self.embed_dim = embed_dim
327
+ self.causal = causal
328
+ self.context = context
329
+ self.rope = rope
330
+ self.num_heads = num_heads
331
+
332
+ out_dim = embed_dim
333
+ out_dim = 3 * embed_dim
334
+ mult = 1
335
+ self.weights_per_step = weights_per_step
336
+ if weights_per_step:
337
+ mult = weights_per_step
338
+ in_proj = nn.Linear(embed_dim, mult * out_dim, bias=False, **factory_kwargs)
339
+ # We try to follow the default PyTorch MHA convention, to easily compare results.
340
+ self.in_proj_weight = in_proj.weight
341
+ self.in_proj_bias = in_proj.bias
342
+ self.out_proj = nn.Linear(
343
+ embed_dim, mult * embed_dim, bias=False, **factory_kwargs
344
+ )
345
+
346
+ def _init_streaming_state(self, batch_size: int) -> _MHAState:
347
+ if self.context is None:
348
+ if self.weights_per_step:
349
+ capacity = self.weights_per_step
350
+ else:
351
+ raise RuntimeError(
352
+ "Cannot create a streaming KVCache without a context to estimate capacity."
353
+ )
354
+ else:
355
+ capacity = self.context
356
+ device = self.in_proj_weight.device
357
+ # TODO: the following estimation will not work great with FSDP.
358
+ dtype = self.in_proj_weight.dtype
359
+ dim_per_head = self.embed_dim // self.num_heads
360
+ kv_cache = RingKVCache(
361
+ batch_size, self.num_heads, dim_per_head, capacity, device, dtype
362
+ )
363
+ return _MHAState(
364
+ kv_cache,
365
+ offset=torch.zeros(1, device=device, dtype=torch.long),
366
+ offset_cpu=0,
367
+ )
368
+
369
+ def _complete_kv(self, k, v) -> KVCacheResult:
370
+ state = self._streaming_state
371
+ if state is None:
372
+ return KVCacheResult.from_kv(k, v)
373
+ else:
374
+ return state.kv_cache.complete(k, v)
375
+
376
+ def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
377
+ state = self._streaming_state
378
+ T = query.shape[1]
379
+
380
+ if state is None:
381
+ offset = torch.zeros(1, device=query.device, dtype=torch.long)
382
+ offset_cpu = 0
383
+ else:
384
+ assert self.causal, "Streaming only available for causal"
385
+ offset = state.offset
386
+ offset_cpu = state.offset_cpu
387
+
388
+ if self.weights_per_step:
389
+ projected = multi_linear(
390
+ self.weights_per_step, self.in_proj_weight, query, offset_cpu
391
+ )
392
+ else:
393
+ projected = nn.functional.linear(query, self.in_proj_weight)
394
+ q, k, v = rearrange(
395
+ projected, "b t (p h d) -> p b h t d", p=3, h=self.num_heads
396
+ )
397
+
398
+ if self.rope:
399
+ q, k = self.rope(q, k, offset, time_before_heads=False)
400
+
401
+ k, v, pos_k = self._complete_kv(k, v)
402
+ if self.causal:
403
+ pos_k = pos_k.view(1, -1)
404
+ pos_q = offset + torch.arange(T, device=q.device, dtype=torch.long).view(
405
+ -1, 1
406
+ )
407
+ delta = pos_q - pos_k
408
+ attn_bias = (pos_k >= 0) & (delta >= 0)
409
+ if self.context is not None:
410
+ attn_bias = attn_bias & (delta < self.context)
411
+ else:
412
+ attn_bias = None
413
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)
414
+
415
+ x = rearrange(x, "b h t d -> b t (h d)")
416
+ if self.weights_per_step:
417
+ x = multi_linear(self.weights_per_step, self.out_proj.weight, x, offset_cpu)
418
+ else:
419
+ x = self.out_proj(x)
420
+ if state is not None:
421
+ state.offset.add_(T)
422
+ state.offset_cpu += T
423
+ return x
424
+
425
+
426
+ @dataclass
427
+ class _LayerState:
428
+ offset_cpu: int
429
+
430
+ def reset(self):
431
+ self.offset_cpu = 0
432
+
433
+
434
+ class StreamingTransformerLayer(StreamingModule[_LayerState]):
435
+ """TransformerLayer with Streaming / Causal support.
436
+
437
+ Args:
438
+ d_model (int): Dimension of the data.
439
+ num_heads (int): Number of heads.
440
+ dim_feedforward (int): Intermediate dimension of FF module.
441
+ causal (bool): Causal mask applied automatically.
442
+ context (int, optional): Receptive field for the causal mask, infinite if None.
443
+ custom (bool): Use custom MHA implementation, for testing / benchmarking.
444
+ rope (`RotaryEmbedding`, optional): Rope embedding to use.
445
+ norm (str): Normalization to use. Currently, only 'layer_norm' is supported.
446
+ layer_scale (float, optional): If not None, LayerScale will be used with the given value as initial scale.
447
+ gating (str): if provided, replaces FFN with special gating, like GLU, GSiGLU etc.
448
+ weights_per_step (int): use different weights per time step. If non zero, should correspond to the
449
+ number of possible time steps.
450
+ skip_self_attn (bool): If True, skips the self-attention module and its normalization.
451
+ device (torch.device, optional): Device on which to initialize.
452
+ dtype (torch.dtype, optional): dtype to use.
453
+ """
454
+
455
+ _fsdp_final = True
456
+
457
+ def __init__(
458
+ self,
459
+ d_model: int,
460
+ num_heads: int,
461
+ dim_feedforward: int | list[int] = 2048,
462
+ causal: bool = False,
463
+ context: tp.Optional[int] = None,
464
+ rope: tp.Optional[RotaryEmbedding] = None,
465
+ norm: str = "layer_norm",
466
+ layer_scale: tp.Optional[float] = None,
467
+ gating: str = "none",
468
+ weights_per_step: int = 0,
469
+ activation=F.gelu,
470
+ skip_self_attn: bool = False,
471
+ device=None,
472
+ dtype=None,
473
+ ):
474
+ super().__init__()
475
+ factory_kwargs = {"device": device, "dtype": dtype}
476
+ # Redefine self_attn to our streaming multi-head attention
477
+ attn_kwargs: tp.Dict[str, tp.Any] = {
478
+ "embed_dim": d_model,
479
+ "num_heads": num_heads,
480
+ }
481
+ if not skip_self_attn:
482
+ self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
483
+ causal=causal,
484
+ context=context,
485
+ rope=rope,
486
+ weights_per_step=weights_per_step,
487
+ **attn_kwargs, # type: ignore
488
+ **factory_kwargs, # type: ignore
489
+ ) # type: ignore
490
+ self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)
491
+ self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)
492
+ # Redefine feedforward layers to expose bias parameter
493
+ self.weights_per_step = weights_per_step
494
+ self.gating: tp.Optional[nn.Module] = None
495
+ self.linear1: tp.Optional[nn.Module] = None
496
+ self.linear2: tp.Optional[nn.Module] = None
497
+ self.activation = activation
498
+ self.skip_self_attn = skip_self_attn
499
+
500
+ if isinstance(dim_feedforward, list):
501
+ assert dim_feedforward
502
+ assert len(dim_feedforward) == weights_per_step, (
503
+ "Length of dim_feedforward must match weights_per_step,"
504
+ f" got {len(dim_feedforward)} != {weights_per_step}"
505
+ )
506
+ if gating == "none":
507
+ assert (
508
+ not weights_per_step
509
+ ), "weights_per_step without gating not supported for now."
510
+ assert not isinstance(
511
+ dim_feedforward, list
512
+ ), "List dim_feedforward without gating not supported for now."
513
+ self.linear1 = nn.Linear(
514
+ d_model, dim_feedforward, bias=False, **factory_kwargs
515
+ )
516
+ self.linear2 = nn.Linear(
517
+ dim_feedforward, d_model, bias=False, **factory_kwargs
518
+ )
519
+ else:
520
+ self.linear1 = None
521
+ self.linear2 = None
522
+ if weights_per_step:
523
+ if isinstance(dim_feedforward, int):
524
+ dim_feedforward = [dim_feedforward] * weights_per_step
525
+ assert isinstance(dim_feedforward, list), dim_feedforward
526
+ self.gating = nn.ModuleList(
527
+ [
528
+ make_gating(gating, d_model, dim, **factory_kwargs)
529
+ for dim in dim_feedforward
530
+ ]
531
+ )
532
+ else:
533
+ assert isinstance(dim_feedforward, int)
534
+ self.gating = make_gating(
535
+ gating, d_model, dim_feedforward, **factory_kwargs
536
+ )
537
+
538
+ self.layer_scale_1: nn.Module
539
+ self.layer_scale_2: nn.Module
540
+ if layer_scale is None:
541
+ self.layer_scale_1 = nn.Identity()
542
+ self.layer_scale_2 = nn.Identity()
543
+ else:
544
+ self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore
545
+ self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs) # type: ignore
546
+
547
+ def _init_streaming_state(self, batch_size: int) -> _LayerState:
548
+ return _LayerState(offset_cpu=0)
549
+
550
+ # feed forward block
551
+ def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
552
+ state = self._streaming_state
553
+ offset = 0
554
+ if state is not None:
555
+ offset = state.offset_cpu
556
+ x_orig = x
557
+ x = self.norm2(x)
558
+ if self.gating is None:
559
+ assert self.linear1 is not None
560
+ assert self.linear2 is not None
561
+ update = self.linear2(self.activation(self.linear1(x)))
562
+ else:
563
+ if self.weights_per_step:
564
+ assert isinstance(self.gating, nn.ModuleList)
565
+ B, T, D = x.shape
566
+ ys = []
567
+ for t in range(T):
568
+ y = self.gating[offset + t](x[:, t : t + 1])
569
+ ys.append(y)
570
+ update = torch.cat(ys, dim=1)
571
+ else:
572
+ update = self.gating(x)
573
+ return x_orig + self.layer_scale_2(update)
574
+
575
+ def _sa_block(self, x: torch.Tensor):
576
+ if self.skip_self_attn:
577
+ return x
578
+ x_orig = x
579
+ x = self.norm1(x)
580
+ update = self.self_attn(x, x, x)
581
+ return x_orig + self.layer_scale_1(update)
582
+
583
+ def forward(self, x: torch.Tensor):
584
+ with ExitStack() as stack:
585
+ if x.device.type != 'cuda':
586
+ stack.enter_context(no_compile())
587
+ x = self._sa_block(x)
588
+ x = self._ff_block(x)
589
+ state = self._streaming_state
590
+ if state:
591
+ state.offset_cpu += x.shape[1]
592
+ return x
593
+
594
+
595
+ @dataclass
596
+ class _TransformerState:
597
+ offset: torch.Tensor
598
+
599
+ def reset(self):
600
+ self.offset.zero_()
601
+
602
+
603
+ class StreamingTransformer(StreamingModule[_TransformerState]):
604
+ """Transformer with Streaming / Causal support.
605
+
606
+ Args:
607
+ d_model (int): Dimension of the data.
608
+ num_heads (int): Number of heads.
609
+ dim_feedforward (int): Intermediate dimension of FF module.
610
+ causal (bool): Causal mask applied automatically.
611
+ context (int, optional): Receptive field for the causal mask, infinite if None.
612
+ layer_scale (float, optional): If not None, LayerScale will be used
613
+ with the given value as initial scale.
614
+ positional_embedding (str): Positional embedding strategy (sin, rope, sin_rope, or none).
615
+ max_period (float): Maximum period of the time embedding.
616
+ positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
617
+ layer_class (subclass of `StreamingTransformerLayer`): class to use
618
+ to initialize the layers, allowing further customization outside of AudioCraft.
619
+ device (torch.device, optional): Device on which to initialize.
620
+ dtype (torch.dtype, optional): dtype to use.
621
+ **kwargs: See `StreamingTransformerLayer`.
622
+ """
623
+
624
+ def __init__(
625
+ self,
626
+ d_model: int,
627
+ num_heads: int,
628
+ num_layers: int,
629
+ dim_feedforward: int | list[int] = 2048,
630
+ causal: bool = False,
631
+ context: tp.Optional[int] = None,
632
+ positional_embedding: str = "sin",
633
+ max_period: float = 10_000,
634
+ positional_scale: float = 1.0,
635
+ betas: tp.Optional[tp.Tuple[float, float]] = None,
636
+ layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
637
+ device=None,
638
+ dtype=None,
639
+ **kwargs,
640
+ ):
641
+ super().__init__()
642
+ assert d_model % num_heads == 0
643
+
644
+ self.positional_embedding = positional_embedding
645
+ self.max_period = max_period
646
+ self.positional_scale = positional_scale
647
+ self.betas = betas
648
+
649
+ assert positional_embedding in {"sin", "rope", "sin_rope", "none"}
650
+ self.rope: tp.Optional[RotaryEmbedding] = None
651
+ if self.positional_embedding in {"rope", "sin_rope"}:
652
+ self.rope = RotaryEmbedding(max_period=max_period)
653
+
654
+ self.layers = nn.ModuleList()
655
+ for _ in range(num_layers):
656
+ self.layers.append(
657
+ layer_class(
658
+ d_model=d_model,
659
+ num_heads=num_heads,
660
+ dim_feedforward=dim_feedforward,
661
+ causal=causal,
662
+ context=context,
663
+ rope=self.rope,
664
+ device=device,
665
+ dtype=dtype,
666
+ **kwargs,
667
+ )
668
+ )
669
+
670
+ def _init_streaming_state(self, batch_size: int) -> _TransformerState:
671
+ device = next(self.parameters()).device
672
+ return _TransformerState(offset=torch.zeros(1, device=device, dtype=torch.long))
673
+
674
+ def forward(self, x: torch.Tensor, *args, **kwargs):
675
+ B, T, C = x.shape
676
+
677
+ state = self._streaming_state
678
+ if state is None:
679
+ offset = torch.zeros(1, dtype=torch.long, device=x.device)
680
+ else:
681
+ offset = state.offset
682
+
683
+ if self.positional_embedding in {"sin", "sin_rope"}:
684
+ positions = torch.arange(T, device=x.device).view(1, -1, 1)
685
+ positions = positions + offset.view(-1, 1, 1)
686
+ pos_emb = create_sin_embedding(
687
+ positions, C, max_period=self.max_period, dtype=x.dtype
688
+ )
689
+ x = x + self.positional_scale * pos_emb
690
+
691
+ for layer in self.layers:
692
+ x = layer(x, *args, **kwargs)
693
+
694
+ if state is not None:
695
+ state.offset.add_(T)
696
+ return x
697
+
698
+
699
+ class ProjectedTransformer(StreamingContainer):
700
+ """Transformer with optional projections of the input and output to different dimensions when needed.
701
+ Supports multiple outputs.
702
+
703
+ Args:
704
+ input_dimension (int): dimension of the input.
705
+ output_dimensions (tuple[int]): dimensions of the outputs.
706
+ d_model (int): inner dimension of the Transformer.
707
+ conv_layout (bool): If True, expects `[B, C, T]` shaped tensors, otherwise, `[B, T, C]`.
708
+ Similarly, the output will have the same layout.
709
+ """
710
+
711
+ def __init__(
712
+ self,
713
+ input_dimension: int,
714
+ output_dimensions: tp.Tuple[int, ...],
715
+ d_model: int,
716
+ *,
717
+ conv_layout: bool = False,
718
+ **kwargs,
719
+ ):
720
+ super().__init__()
721
+ self.transformer = StreamingTransformer(d_model=d_model, **kwargs)
722
+ self.input_dimension = input_dimension
723
+ self.output_dimensions = output_dimensions
724
+ self.conv_layout = conv_layout
725
+ self.input_proj = None
726
+ if d_model != input_dimension:
727
+ self.input_proj = nn.Linear(input_dimension, d_model, bias=False)
728
+
729
+ self.output_projs = nn.ModuleList()
730
+ for output_dimension in output_dimensions:
731
+ if d_model == output_dimension:
732
+ self.output_projs.append(nn.Identity())
733
+ else:
734
+ self.output_projs.append(
735
+ nn.Linear(d_model, output_dimension, bias=False)
736
+ )
737
+
738
+ def forward(self, x, *args, **kwargs):
739
+ if self.conv_layout:
740
+ x = x.transpose(1, 2)
741
+ if self.input_proj is not None:
742
+ x = self.input_proj(x)
743
+ z = self.transformer(x, *args, **kwargs)
744
+ ys = []
745
+ for output_proj in self.output_projs:
746
+ y = output_proj(z)
747
+ if self.conv_layout:
748
+ y = y.transpose(1, 2)
749
+ ys.append(y)
750
+ return ys
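
A minimal usage sketch for the `ProjectedTransformer` wrapper above, run in plain (non-streaming) mode. This is not part of the diff: the import path and all concrete dimensions are assumptions for illustration only.

```python
import torch
# Import path assumed from the relative imports in this file; adjust to the actual package layout.
from moshi.modules.transformer import ProjectedTransformer

model = ProjectedTransformer(
    input_dimension=128,
    output_dimensions=(128, 64),   # two output heads of different widths
    d_model=256,
    conv_layout=True,              # accept and return [B, C, T] shaped tensors
    num_heads=4,
    num_layers=2,
    causal=True,
    context=100,                   # causal attention limited to the last 100 steps
)
x = torch.randn(1, 128, 50)        # [B, C, T] because conv_layout=True
y_main, y_aux = model(x)           # one output per entry in output_dimensions
print(y_main.shape, y_aux.shape)   # torch.Size([1, 128, 50]) torch.Size([1, 64, 50])
```

Streaming decoding additionally relies on the `StreamingModule` machinery imported from `.streaming` (not shown in this diff), which allocates the per-layer `RingKVCache` and advances the offsets one step at a time.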
moshi/quantization/__init__.py ADDED
@@ -0,0 +1,13 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+ """RVQ."""
11
+ # flake8: noqa
12
+ from .vq import ResidualVectorQuantizer, SplitResidualVectorQuantizer
13
+ from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
moshi/quantization/base.py ADDED
@@ -0,0 +1,170 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+ """
12
+ Base class for all quantizers.
13
+ """
14
+
15
+ from dataclasses import dataclass, field
16
+ import typing as tp
17
+
18
+ import torch
19
+ from torch import nn
20
+
21
+
22
+ @dataclass
23
+ class QuantizedResult:
24
+ x: torch.Tensor
25
+ codes: torch.Tensor
26
+ bandwidth: torch.Tensor # bandwidth in kb/s used, per batch item.
27
+ penalty: tp.Optional[torch.Tensor] = None
28
+ metrics: dict = field(default_factory=dict)
29
+
30
+
31
+ class BaseQuantizer(nn.Module):
32
+ """Base class for quantizers."""
33
+
34
+ def __init__(self):
35
+ super().__init__()
36
+ self._ema_frozen = False
37
+
38
+ def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
39
+ """
40
+ Given input tensor x, returns first the quantized (or approximately quantized)
41
+ representation along with quantized codes, bandwidth, and any penalty term for the loss.
42
+ Finally, this returns a dict of metrics to update logging etc.
43
+ Frame rate must be passed so that the bandwidth is properly computed.
44
+ """
45
+ raise NotImplementedError()
46
+
47
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
48
+ """Encode a given input tensor with the specified sample rate at the given bandwidth."""
49
+ raise NotImplementedError()
50
+
51
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
52
+ """Decode the given codes to the quantized representation."""
53
+ raise NotImplementedError()
54
+
55
+ @property
56
+ def cardinality(self) -> int:
57
+ """Cardinality of each codebook."""
58
+ raise NotImplementedError()
59
+
60
+ @property
61
+ def total_codebooks(self) -> int:
62
+ """Total number of codebooks."""
63
+ raise NotImplementedError()
64
+
65
+ @property
66
+ def num_codebooks(self) -> int:
67
+ """Number of active codebooks."""
68
+ raise NotImplementedError()
69
+
70
+ @property
71
+ def semantic_quantizer(self) -> 'BaseQuantizer':
72
+ """This returns the quantizer that models the first level of the hierarchy (typically semantic).
73
+
74
+ In this case, it's the quantizer itself.
75
+ """
76
+ return self
77
+
78
+ @property
79
+ def acoustic_quantizer(self) -> 'BaseQuantizer':
80
+ """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic).
81
+
82
+ In this case, it's the quantizer itself.
83
+ """
84
+ return self
85
+
86
+ def set_num_codebooks(self, n: int) -> None:
87
+ """Set the number of active codebooks."""
88
+ raise NotImplementedError()
89
+
90
+ @property
91
+ def ema_frozen(self) -> bool:
92
+ """Whether to apply ema to the codebooks."""
93
+ return self._ema_frozen
94
+
95
+ def ema_frozen_(self, ema_frozen: bool) -> None:
96
+ """Set whether ema should be applied to the codebooks."""
97
+ self._ema_frozen = ema_frozen
98
+
99
+
100
+ class DummyQuantizer(BaseQuantizer):
101
+ """Fake quantizer that actually does not perform any quantization."""
102
+
103
+ def __init__(
104
+ self,
105
+ dimension: int,
106
+ input_dimension: tp.Optional[int] = None,
107
+ output_dimension: tp.Optional[int] = None,
108
+ ):
109
+ super().__init__()
110
+ self.dimension = dimension
111
+ self.input_dimension = input_dimension or dimension
112
+ self.output_dimension = output_dimension or dimension
113
+ self.input_proj: torch.nn.Module
114
+ self.output_proj: torch.nn.Module
115
+ if self.input_dimension == self.dimension:
116
+ self.input_proj = torch.nn.Identity()
117
+ else:
118
+ self.input_proj = torch.nn.Conv1d(
119
+ self.input_dimension, self.dimension, 1, bias=False
120
+ )
121
+ if self.input_dimension == self.dimension:
122
+ self.output_proj = torch.nn.Identity()
123
+ else:
124
+ self.output_proj = torch.nn.Conv1d(
125
+ self.dimension, self.output_dimension, 1, bias=False
126
+ )
127
+
128
+ def forward(self, x: torch.Tensor, frame_rate: int):
129
+ q = x.unsqueeze(1)
130
+ x = self.output_proj(self.input_proj(x))
131
+ return QuantizedResult(
132
+ x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)
133
+ )
134
+
135
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
136
+ """Encode a given input tensor with the specified sample rate at the given bandwidth.
137
+ In the case of the DummyQuantizer, the codes are actually identical
138
+ to the input and resulting quantized representation as no quantization is done.
139
+ """
140
+ x = self.input_proj(x)
141
+ return x.unsqueeze(1)
142
+
143
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
144
+ """Decode the given codes to the quantized representation.
145
+ In the case of the DummyQuantizer, the codes are actually identical
146
+ to the input and resulting quantized representation as no quantization is done.
147
+ """
148
+ y = codes.squeeze(1)
149
+ return self.output_proj(y)
150
+
151
+ @property
152
+ def total_codebooks(self):
153
+ """Total number of codebooks."""
154
+ return 1
155
+
156
+ @property
157
+ def num_codebooks(self):
158
+ """Total number of codebooks."""
159
+ return self.total_codebooks
160
+
161
+ def set_num_codebooks(self, n: int):
162
+ """Set the number of active codebooks."""
163
+ raise AttributeError(
164
+ "Cannot override the number of codebooks for the dummy quantizer"
165
+ )
166
+
167
+ @property
168
+ def cardinality(self) -> int:
169
+ """Cardinality of each codebook."""
170
+ return 1
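
A quick sketch of the pass-through behaviour of the `DummyQuantizer` defined above; the import path is an assumption for illustration.

```python
import torch
# Import path assumed from the file location in this diff.
from moshi.quantization.base import DummyQuantizer

q = DummyQuantizer(dimension=32)
x = torch.randn(2, 32, 50)          # [B, C, T]
codes = q.encode(x)                  # [B, 1, C, T]: the "codes" are just the (projected) input
y = q.decode(codes)                  # [B, C, T], identical to x since no projection is needed here
res = q(x, frame_rate=50)            # QuantizedResult carrying a bandwidth estimate in kb/s
print(codes.shape, y.shape, float(res.bandwidth))
```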
moshi/quantization/core_vq.py ADDED
@@ -0,0 +1,384 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+ import typing as tp
12
+
13
+ from einops import rearrange
14
+ import torch
15
+ from torch import nn
16
+ from torch import distributed
17
+ import torch.nn.functional as F
18
+
19
+
20
+ class _CodebookForwardResult(tp.NamedTuple):
21
+ quantized: torch.Tensor
22
+ codes: torch.Tensor
23
+ metrics: tp.Dict[str, torch.Tensor]
24
+
25
+
26
+ class _VQForwardResult(tp.NamedTuple):
27
+ quantized: torch.Tensor
28
+ codes: torch.Tensor
29
+ loss: torch.Tensor
30
+ metrics: tp.Dict[str, torch.Tensor]
31
+
32
+
33
+ def _ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> None:
34
+ moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))
35
+
36
+
37
+ def _uniform_init(*shape: int) -> torch.Tensor:
38
+ t = torch.empty(shape)
39
+ nn.init.kaiming_uniform_(t)
40
+ return t
41
+
42
+
43
+ def _sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor:
44
+ num_samples, device = samples.shape[0], samples.device
45
+
46
+ if num_samples >= num:
47
+ indices = torch.randperm(num_samples, device=device)[:num]
48
+ else:
49
+ indices = torch.randint(0, num_samples, (num,), device=device)
50
+
51
+ return samples[indices]
52
+
53
+
54
+ def _compute_entropy(usage: torch.Tensor) -> torch.Tensor:
55
+ # Usage is some unnormalized distribution.
56
+ proba = usage / usage.sum()
57
+ p_log_p = torch.where(
58
+ proba == 0, zero_scalar(usage.device), proba * torch.log(proba)
59
+ )
60
+ return -p_log_p.sum()
61
+
62
+
63
+ def _is_distributed() -> bool:
64
+ # Checks if we need to use distributed routines.
65
+ return distributed.is_initialized() and distributed.get_world_size() > 1
66
+
67
+
68
+ def zero_scalar(device) -> torch.Tensor:
69
+ """Returns a 0. value on the given device without introducing a synchronization point."""
70
+ return torch.zeros([1], device=device)[0]
71
+
72
+
73
+ class EuclideanCodebook(nn.Module):
74
+ """Codebook with Euclidean distance.
75
+
76
+ Args:
77
+ dim (int): Dimension.
78
+ codebook_size (int): Codebook size.
79
+ decay (float): Decay for exponential moving average over the codebooks.
80
+ epsilon (float): Epsilon value for numerical stability.
81
+ threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
82
+ is replaced. This is expressed as a fraction of the usage a centroid would get under
83
+ a uniform distribution, so that it doesn't depend on the batch size etc.
84
+ replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
85
+ to avoid the centroid getting replaced too quickly.
86
+ check_unused_every (int): Check for unused centroids every `check_unused_every` iterations.
87
+ This is to avoid too many synchronization points.
88
+
89
+ Buffers:
90
+ cluster_usage (torch.Tensor): EMA of the cluster usage per batch, e.g. this will
91
+ be dependent on the batch size etc.
92
+ embedding_sum (torch.Tensor): EMA of the sum of the assigned points to each cluster.
93
+ In particular, this can be normalized by `cluster_usage` to obtain the
94
+ actual cluster centroids.
95
+ """
96
+
97
+ def __init__(
98
+ self,
99
+ dim: int,
100
+ codebook_size: int,
101
+ decay: float = 0.99,
102
+ epsilon: float = 1e-5,
103
+ threshold_usage_ratio: float = 0.1,
104
+ replaced_usage_ratio: float = 1.0,
105
+ check_unused_every: int = 5,
106
+ ):
107
+ super().__init__()
108
+ self.decay = decay
109
+ embedding = torch.zeros(codebook_size, dim)
110
+
111
+ self.dim = dim
112
+ self.codebook_size = codebook_size
113
+
114
+ self.epsilon = epsilon
115
+ self.threshold_usage_ratio = threshold_usage_ratio
116
+ self.replaced_usage_ratio = replaced_usage_ratio
117
+ self.check_unused_every = check_unused_every
118
+ self._next_unused_check = check_unused_every
119
+
120
+ self.register_buffer("_initialized", torch.tensor([False], dtype=torch.float))
121
+ self.register_buffer("cluster_usage", torch.ones(codebook_size))
122
+ self.register_buffer("embedding_sum", embedding)
123
+ self.register_buffer("_embedding", None, persistent=False)
124
+ self._cached_initialized = False
125
+
126
+ def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs) -> None:
127
+ # Mapping old names to new names
128
+ mappings = {
129
+ "inited": "_initialized",
130
+ "cluster_size": "cluster_usage",
131
+ "embed_avg": "embedding_sum",
132
+ "embed_sum": "embedding_sum",
133
+ }
134
+ for old_name, new_name in mappings.items():
135
+ old_name = prefix + old_name
136
+ if old_name in state_dict:
137
+ value = state_dict.pop(old_name)
138
+ if new_name is not None:
139
+ state_dict[prefix + new_name] = value
140
+ super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)
141
+
142
+ @property
143
+ def embedding(self) -> torch.Tensor:
144
+ if self._embedding is None:
145
+ embedding = (
146
+ self.embedding_sum / self.cluster_usage.clamp(min=self.epsilon)[:, None]
147
+ )
148
+ self.register_buffer("_embedding", embedding, persistent=False)
149
+ return embedding
150
+ return self._embedding
151
+
152
+ def _broadcast_buffers(self) -> None:
153
+ if _is_distributed():
154
+ for buffer in self.buffers():
155
+ distributed.broadcast(buffer, 0)
156
+
157
+ def _replace_expired_codes(self, samples: torch.Tensor, mask: torch.Tensor) -> None:
158
+ # Replaces expired centroids, as indicated by `mask` (a true value indicates the code needs to be replaced).
159
+ # The new codes are sampled from the batch `samples`.
160
+ new_vectors = _sample_vectors(samples, self.codebook_size)
161
+ replace_cluster_usage = (
162
+ self.replaced_usage_ratio * self.cluster_usage.sum() / self.codebook_size
163
+ )
164
+ self.embedding_sum[:] = torch.where(
165
+ mask[:, None], replace_cluster_usage * new_vectors, self.embedding_sum
166
+ )
167
+ self.cluster_usage[:] = torch.where(
168
+ mask, replace_cluster_usage, self.cluster_usage
169
+ )
170
+
171
+ def _reshape_input(self, x: torch.Tensor) -> torch.Tensor:
172
+ # Flattens all the dimensions but the last one, e.g. return a vector of shape `[N, D]`.
173
+ x = rearrange(x, "... d -> (...) d")
174
+ return x
175
+
176
+ def _reshape_codes(self, codes: torch.Tensor, shape: torch.Size) -> torch.Tensor:
177
+ return codes.view(*shape[:-1])
178
+
179
+ def _quantize(self, x: torch.Tensor) -> torch.Tensor:
180
+ # Projects each vector in `x` onto the nearest centroid and returns its index.
181
+ # `x` should be `[N, D]` with `N` the number of input vectors and `D` the dimension.
182
+ assert x.dim() == 2
183
+ dists = torch.cdist(x[None], self.embedding[None], p=2)[0]
184
+ codes = dists.argmin(dim=-1)
185
+ return codes
186
+
187
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
188
+ """Given a tensor `x` of shape `[*, D]`, returns a tensor of integer codes of shape `[*]`.
189
+ The codes are defined as the indexes of the centroids nearest to each vector in `x`.
190
+ """
191
+ assert x.dtype.is_floating_point, f"Input should be floats, got {x.dtype}"
192
+ shape = x.shape
193
+ x = self._reshape_input(x)
194
+ codes = self._quantize(x)
195
+ codes = self._reshape_codes(codes, shape)
196
+ return codes
197
+
198
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
199
+ """Given a tensor of codes of shape `[*]`, returns a tensor of shape `[*, D]`,
200
+ corresponding to the centroids associated to each code index.
201
+ """
202
+ assert (
203
+ not codes.dtype.is_floating_point
204
+ ), f"Codes should be integers, got {codes.dtype}"
205
+ quantized = F.embedding(codes, self.embedding)
206
+ return quantized
207
+
208
+ def forward(
209
+ self, x: torch.Tensor, initialize: bool = True
210
+ ) -> _CodebookForwardResult:
211
+ shape = x.shape
212
+ x = self._reshape_input(x)
213
+
214
+ flat_codes = self._quantize(x)
215
+ codes = self._reshape_codes(flat_codes, shape)
216
+ quantized = self.decode(codes)
217
+ metrics: tp.Dict[str, torch.Tensor] = {}
218
+
219
+ return _CodebookForwardResult(quantized, codes, metrics)
220
+
221
+
222
+ class VectorQuantization(nn.Module):
223
+ """Vector quantization implementation.
224
+ Currently supports only euclidean distance.
225
+
226
+ Args:
227
+ dim (int): Dimension
228
+ codebook_size (int): Codebook size
229
+ codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
230
+ decay (float): Decay for exponential moving average over the codebooks.
231
+ epsilon (float): Epsilon value for numerical stability.
232
+ threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
233
+ is replaced. This is expressed as a fraction of the usage a centroid would get under
234
+ a uniform distribution, so that it doesn't depend on the batch size etc.
235
+ replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
236
+ to avoid the centroid getting replaced too quickly.
237
+ check_unused_every (int): Check for unused centroids every `check_unused_every` iterations.
238
+ This is to avoid too many synchronization points.
239
+ """
240
+
241
+ def __init__(
242
+ self,
243
+ dim: int,
244
+ codebook_size: int,
245
+ codebook_dim: tp.Optional[int] = None,
246
+ decay: float = 0.99,
247
+ epsilon: float = 1e-5,
248
+ threshold_usage_ratio: float = 0.1,
249
+ **kwargs,
250
+ ):
251
+ super().__init__()
252
+ if codebook_dim is None:
253
+ codebook_dim = dim
254
+
255
+ requires_projection = codebook_dim != dim
256
+ self.project_in = (
257
+ nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
258
+ )
259
+ self.project_out = (
260
+ nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
261
+ )
262
+ self.epsilon = epsilon
263
+ self._codebook = EuclideanCodebook(
264
+ dim=codebook_dim,
265
+ codebook_size=codebook_size,
266
+ decay=decay,
267
+ epsilon=epsilon,
268
+ threshold_usage_ratio=threshold_usage_ratio,
269
+ **kwargs,
270
+ )
271
+ self.codebook_size = codebook_size
272
+
273
+ @property
274
+ def embedding(self):
275
+ return self._codebook.embedding
276
+
277
+ def _rearrange_input(self, x):
278
+ x = rearrange(x, "b d n -> b n d")
279
+ return x
280
+
281
+ def _rearrange_output(self, quantized):
282
+ quantized = rearrange(quantized, "b n d -> b d n")
283
+ return quantized
284
+
285
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
286
+ """Encodes `x` into discrete integer codes."""
287
+ x = self._rearrange_input(x)
288
+ x = self.project_in(x)
289
+ codes = self._codebook.encode(x)
290
+ return codes
291
+
292
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
293
+ """Converts integer codes into quantized vectors."""
294
+ quantized = self._codebook.decode(codes)
295
+ quantized = self.project_out(quantized)
296
+ quantized = self._rearrange_output(quantized)
297
+ return quantized
298
+
299
+ def forward(self, x: torch.Tensor, initialize: bool = True) -> _VQForwardResult:
300
+ x = self._rearrange_input(x)
301
+ quantized, codes, metrics = self._codebook(x, initialize=initialize)
302
+
303
+ loss = zero_scalar(x.device)
304
+
305
+ quantized = self.project_out(quantized)
306
+ quantized = self._rearrange_output(quantized)
307
+
308
+ return _VQForwardResult(quantized, codes, loss, metrics)
309
+
310
+
311
+ class ResidualVectorQuantization(nn.Module):
312
+ """Residual vector quantization implementation.
313
+
314
+ Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
315
+ """
316
+
317
+ def __init__(self, *, num_quantizers: int, codebook_offset: int, **kwargs):
318
+ super().__init__()
319
+ self.layers = nn.ModuleList(
320
+ [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
321
+ )
322
+ self.codebook_offset = codebook_offset
323
+
324
+ def forward(
325
+ self, x: torch.Tensor, n_q: tp.Optional[int] = None
326
+ ) -> _VQForwardResult:
327
+ """
328
+ Args:
329
+ x (torch.Tensor): input tensor to quantize, of shape `[B, C, T]`.
330
+ n_q (int or None): if provided, number of codebook levels to use in RVQ.
331
+ """
332
+
333
+ quantized_out = zero_scalar(x.device)
334
+ residual = x
335
+
336
+ all_losses = []
337
+ all_codes = []
338
+ all_metrics: tp.Dict[str, torch.Tensor] = {}
339
+
340
+ n_q = n_q or len(self.layers)
341
+ previous_layer_is_initialized = True
342
+
343
+ for i, layer in enumerate(self.layers[:n_q]): # type: ignore
344
+ quantized, codes, loss, metrics = layer(
345
+ residual, initialize=previous_layer_is_initialized
346
+ )
347
+
348
+ quantized = quantized.detach()
349
+ residual = residual - quantized
350
+ quantized_out = quantized_out + quantized
351
+
352
+ all_codes.append(codes)
353
+ all_losses.append(loss)
354
+
355
+ for key, value in metrics.items():
356
+ if key in all_metrics:
357
+ all_metrics[key] += value / n_q
358
+ else:
359
+ all_metrics[key] = value / n_q
360
+ all_metrics[key + f"_{i + self.codebook_offset}"] = value
361
+
362
+ out_losses, out_codes = map(torch.stack, (all_losses, all_codes))
363
+ return _VQForwardResult(quantized_out, out_codes, out_losses, all_metrics)
364
+
365
+ def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
366
+ """Encodes `x` into discrete integer codes. If `n_q` is provided, only uses the first `n_q` codebook levels."""
367
+ residual = x
368
+ all_indices = []
369
+ n_q = n_q or len(self.layers)
370
+ for layer in self.layers[:n_q]: # type: ignore
371
+ indices = layer.encode(residual)
372
+ quantized = layer.decode(indices)
373
+ residual = residual - quantized
374
+ all_indices.append(indices)
375
+ out_indices = torch.stack(all_indices)
376
+ return out_indices
377
+
378
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
379
+ """Converts the integer codes into quantized vectors."""
380
+ quantized = zero_scalar(codes.device)
381
+ for idx, layer_codes in enumerate(codes):
382
+ layer = self.layers[idx]
383
+ quantized = quantized + layer.decode(layer_codes)
384
+ return quantized
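
A minimal sketch of the `[K, B, T]` code layout produced by the `ResidualVectorQuantization` module above. The import path is an assumption; with freshly constructed (all-zero) codebooks every vector maps to code 0, so this only illustrates tensor shapes, not a trained quantizer.

```python
import torch
# Import path assumed from the file location in this diff.
from moshi.quantization.core_vq import ResidualVectorQuantization

rvq = ResidualVectorQuantization(
    num_quantizers=4,
    codebook_offset=0,
    dim=16,
    codebook_size=256,
)
x = torch.randn(2, 16, 10)       # [B, D, T]
codes = rvq.encode(x, n_q=2)     # [n_q, B, T]: one code map per residual level
y = rvq.decode(codes)            # [B, D, T]: sum of the decoded residual levels
print(codes.shape, y.shape)      # torch.Size([2, 2, 10]) torch.Size([2, 16, 10])
```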
moshi/quantization/vq.py ADDED
@@ -0,0 +1,340 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+ import math
12
+ import typing as tp
13
+
14
+ import torch
15
+
16
+ from .base import BaseQuantizer, QuantizedResult
17
+ from .core_vq import ResidualVectorQuantization
18
+
19
+
20
+ class ResidualVectorQuantizer(BaseQuantizer):
21
+ """Residual Vector Quantizer.
22
+
23
+ Args:
24
+ dimension (int): Dimension of the codebooks.
25
+ input_dimension (None or int): dimension of the input, defaults to `dimension` if not provided.
26
+ output_dimension (None or int): dimension of the output, defaults to `dimension` if not provided.
27
+ n_q (int): Number of vector quantizers used.
28
+ q_dropout (bool): Random quantizer drop out at train time.
29
+ no_quantization_rate (float): Gives the probability of applying no quantization at all
30
+ at train time. The RVQ codebooks will still get the input value to learn the proper codebook.
31
+ bins (int): Codebook size.
32
+ decay (float): Decay for exponential moving average over the codebooks.
33
+ threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
34
+ is replaced. This is expressed as a fraction of the usage a centroid would get under
35
+ a uniform distribution, so that it doesn't depend on the batch size etc.
36
+ replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
37
+ to avoid the centroid getting replaced too quickly.
38
+ codebook_offset (int): Offset to use for the codebook indices. This is useful when using multiple quantizers
39
+ such as in SplitResidualVectorQuantizer.
40
+ force_projection (bool): Whether to force input and output projections even when dimension is constant.
41
+ generator_seed (int or None): seed used to initialize the RNG used for no quantization.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ dimension: int = 128,
47
+ input_dimension: tp.Optional[int] = None,
48
+ output_dimension: tp.Optional[int] = None,
49
+ n_q: int = 8,
50
+ q_dropout: bool = False,
51
+ q_first_only_proba: float = 0.0,
52
+ no_quantization_rate: float = 0.0,
53
+ bins: int = 1024,
54
+ decay: float = 0.99,
55
+ threshold_usage_ratio: float = 0.1,
56
+ replaced_usage_ratio: float = 1.0,
57
+ codebook_offset: int = 0,
58
+ force_projection: bool = False,
59
+ generator_seed: tp.Optional[int] = None,
60
+ ):
61
+ super().__init__()
62
+ self.max_n_q = n_q
63
+ self.n_q = n_q
64
+ self.q_dropout = q_dropout
65
+ self.no_quantization_rate = no_quantization_rate
66
+ self.q_first_only_proba = q_first_only_proba
67
+ self.dimension = dimension
68
+ self.input_dimension = input_dimension or dimension
69
+ self.output_dimension = output_dimension or dimension
70
+ self.bins = bins
71
+ self.decay = decay
72
+ self.input_proj: torch.nn.Module
73
+ self.output_proj: torch.nn.Module
74
+ self.generator = None
75
+ if generator_seed is not None:
76
+ self.generator = torch.Generator(
77
+ device="cuda" if torch.cuda.is_available() else "cpu"
78
+ )
79
+ self.generator.manual_seed(generator_seed)
80
+ if self.input_dimension == self.dimension and not force_projection:
81
+ self.input_proj = torch.nn.Identity()
82
+ else:
83
+ self.input_proj = torch.nn.Conv1d(
84
+ self.input_dimension, self.dimension, 1, bias=False
85
+ )
86
+ if self.output_dimension == self.dimension and not force_projection:
87
+ self.output_proj = torch.nn.Identity()
88
+ else:
89
+ self.output_proj = torch.nn.Conv1d(
90
+ self.dimension, self.output_dimension, 1, bias=False
91
+ )
92
+ self.vq = ResidualVectorQuantization(
93
+ dim=self.dimension,
94
+ codebook_size=self.bins,
95
+ num_quantizers=self.n_q,
96
+ decay=self.decay,
97
+ threshold_usage_ratio=threshold_usage_ratio,
98
+ replaced_usage_ratio=replaced_usage_ratio,
99
+ codebook_offset=codebook_offset,
100
+ )
101
+
102
+ def forward(self, x: torch.Tensor, frame_rate: int):
103
+ """
104
+ Args:
105
+ x (torch.Tensor): Input tensor of shape [B, C, T] with `C` number of channels.
106
+ frame_rate (int): frame rate of the input (e.g `T = frame_rate * duration`), used to compute
107
+ the bandwidth.
108
+
109
+ Returns:
110
+ QuantizedResult: Quantized result with the following attributes:
111
+ - `x` (torch.Tensor): Quantized tensor of shape [B, C, T].
112
+ - `codes` (torch.Tensor): Quantized codes of shape [B, K, T] with `K` number of codebooks.
113
+ - `bw` (torch.Tensor): Bandwidth of the quantized tensor in kbits per second.
114
+ - `penalty` (torch.Tensor): Commitment loss.
115
+ - `metrics` (dict): RVQ metrics, in particular rate of dead code replacement, and entropy.
116
+ """
117
+ n_q = self.n_q
118
+ x = self.input_proj(x)
119
+
120
+ bw_per_q = math.log2(self.bins) * frame_rate / 1000
121
+ quantized, codes, commit_loss, metrics = self.vq(x, n_q=n_q)
122
+ B, _, _ = quantized.shape
123
+ quantized = self.output_proj(quantized)
124
+ codes = codes.transpose(0, 1)
125
+ # codes is [B, K, T], with T frames, K nb of codebooks.
126
+ bw = torch.tensor(n_q * bw_per_q).to(x)
127
+ return QuantizedResult(
128
+ quantized, codes, bw, penalty=torch.mean(commit_loss), metrics=metrics
129
+ )
130
+
131
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
132
+ """Encode a given input tensor with the specified frame rate at the given bandwidth.
133
+ The RVQ encode method sets the appropriate number of quantizer to use
134
+ and returns indices for each quantizer.
135
+ """
136
+ n_q = self.n_q
137
+ if x.shape[-1] == 0:
138
+ return torch.empty((x.shape[0], n_q, 0), device=x.device, dtype=torch.int64)
139
+
140
+ x = self.input_proj(x)
141
+ codes = self.vq.encode(x, n_q=n_q)
142
+ codes = codes.transpose(0, 1)
143
+ # codes is [B, K, T], with T frames, K nb of codebooks.
144
+ return codes
145
+
146
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
147
+ """Decode the given codes to the quantized representation."""
148
+ # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
149
+ codes = codes.transpose(0, 1)
150
+ quantized = self.vq.decode(codes)
151
+ quantized = self.output_proj(quantized)
152
+ return quantized
153
+
154
+ @property
155
+ def total_codebooks(self):
156
+ return self.max_n_q
157
+
158
+ @property
159
+ def num_codebooks(self):
160
+ return self.n_q
161
+
162
+ def set_num_codebooks(self, n: int):
163
+ assert n >= 0 and n <= self.max_n_q
164
+ self.n_q = n
165
+
166
+ @property
167
+ def cardinality(self) -> int:
168
+ return self.bins
169
+
170
+
171
+ class SplitResidualVectorQuantizer(BaseQuantizer):
172
+ """Residual Vector Quantizer with separate projections for the first quantizer and the rest.
173
+
174
+ Args:
175
+ n_q (int): Number of residual vector quantizers used.
176
+ n_semantic_q (int): Number of residual vector quantizers used for the semantic quantizer.
177
+ no_quantization_mode (str): if 'true_skip', when doing no quantization, the input will not go
178
+ through the sub quantizers. If `independent`, independent decisions are taken by
179
+ the semantic and acoustic quantizers. If `same` (the default), the same decision is taken by both.
180
+ **kwargs: Arguments to the constructor of `ResidualVectorQuantizer` that are shared between both.
181
+ """
182
+
183
+ def __init__(
184
+ self,
185
+ *,
186
+ n_q: int = 8,
187
+ no_quantization_rate: float = 0.0,
188
+ no_quantization_mode: str = "same",
189
+ n_q_semantic: int = 1,
190
+ **kwargs,
191
+ ):
192
+ super().__init__()
193
+ assert n_q > n_q_semantic, (
194
+ f"Number of quantizers {n_q} must be larger "
195
+ f"than the number of semantic quantizers {n_q_semantic}."
196
+ )
197
+ self.max_n_q = n_q
198
+ self.n_q_semantic = n_q_semantic
199
+ self.n_q_acoustic = n_q - n_q_semantic
200
+ if no_quantization_mode == "true_skip":
201
+ self.no_quantization_rate = no_quantization_rate
202
+ # Setting to zero for the underlying RVQ.
203
+ no_quantization_rate = 0.0
204
+ else:
205
+ self.no_quantization_rate = 0.0
206
+ if no_quantization_mode == "same":
207
+ kwargs["generator_seed"] = 1234
208
+ kwargs["no_quantization_rate"] = no_quantization_rate
209
+ q_dropout = kwargs.pop("q_dropout", False)
210
+ self.rvq_first = ResidualVectorQuantizer(
211
+ n_q=n_q_semantic, force_projection=True, q_dropout=False, **kwargs
212
+ )
213
+ self.rvq_rest = ResidualVectorQuantizer(
214
+ n_q=n_q - n_q_semantic,
215
+ codebook_offset=1,
216
+ force_projection=True,
217
+ q_dropout=q_dropout,
218
+ **kwargs,
219
+ )
220
+ if no_quantization_mode == "true_skip":
221
+ assert self.rvq_first.input_dimension == self.rvq_first.output_dimension
222
+ assert self.rvq_rest.input_dimension == self.rvq_rest.output_dimension
223
+
224
+ def _renorm_and_add(
225
+ self,
226
+ first_val: torch.Tensor,
227
+ rest_val: torch.Tensor,
228
+ n_q_semantic: int,
229
+ n_q_acoustic: int,
230
+ ):
231
+ """Renormalizes values from `rvq_first` and `rvq_rest` and adds them.
232
+
233
+ This allows correcting statistics that are normalized by the number of quantizers. To renormalize, we use the
234
+ number of quantizers that are actually used, e.g. taking into account quantizer dropout.
235
+ """
236
+ n_q = n_q_semantic + n_q_acoustic
237
+ renorm_first_val = first_val * n_q_semantic / n_q
238
+ renorm_rest_val = rest_val * n_q_acoustic / n_q
239
+ return renorm_first_val + renorm_rest_val
240
+
241
+ def forward(self, x: torch.Tensor, frame_rate: int):
242
+ """
243
+ Args:
244
+ x (torch.Tensor): Input tensor of shape [B, C, T] with `C` number of channels.
245
+ frame_rate (int): frame rate of the input (e.g `T = frame_rate * duration`), used to compute
246
+ the bandwidth.
247
+
248
+ Returns:
249
+ QuantizedResult: Quantized result with the following attributes:
250
+ - `x` (torch.Tensor): Quantized tensor of shape [B, C, T].
251
+ - `codes` (torch.Tensor): Quantized codes of shape [B, K, T] with `K` number of codebooks.
252
+ - `bw` (torch.Tensor): Bandwidth of the quantized tensor in kbits per second.
253
+ - `penalty` (torch.Tensor): Commitment loss.
254
+ - `metrics` (dict): RVQ metrics, in particular rate of dead code replacement, and entropy.
255
+ """
256
+ semantic_result = self.rvq_first(x, frame_rate)
257
+ if self.n_q == self.n_q_semantic:
258
+ return semantic_result
259
+ acoustic_result = self.rvq_rest(x, frame_rate)
260
+ full_quantized_emb = semantic_result.x + acoustic_result.x
261
+ full_quantized_codes = torch.cat(
262
+ [semantic_result.codes, acoustic_result.codes], dim=1
263
+ )
264
+ # This is the actual number of quantizers used, e.g. taking into account quantizer dropout.
265
+ n_q_semantic = semantic_result.codes.shape[1]
266
+ n_q_acoustic = acoustic_result.codes.shape[1]
267
+ full_quantized_bandwidth = semantic_result.bandwidth + acoustic_result.bandwidth
268
+ full_quantized_penalty = self._renorm_and_add(
269
+ semantic_result.penalty, acoustic_result.penalty, n_q_semantic, n_q_acoustic
270
+ )
271
+ full_quantized_metrics = semantic_result.metrics
272
+ for key, value in acoustic_result.metrics.items():
273
+ if key in full_quantized_metrics:
274
+ full_quantized_metrics[key] = self._renorm_and_add(
275
+ full_quantized_metrics[key], value, n_q_semantic, n_q_acoustic
276
+ )
277
+ else:
278
+ full_quantized_metrics[key] = value
279
+ return QuantizedResult(
280
+ full_quantized_emb,
281
+ full_quantized_codes,
282
+ full_quantized_bandwidth,
283
+ penalty=full_quantized_penalty,
284
+ metrics=full_quantized_metrics,
285
+ )
286
+
287
+ def encode(self, x: torch.Tensor) -> torch.Tensor:
288
+ """Encode a given input tensor with the specified frame rate at the given bandwidth.
289
+ The RVQ encode method uses the currently set number of quantizers
290
+ and returns indices for each quantizer.
291
+ """
292
+ codes = self.rvq_first.encode(x)
293
+ if self.n_q > self.n_q_semantic:
294
+ acoustic_codes = self.rvq_rest.encode(x)
295
+ codes = torch.cat([codes, acoustic_codes], dim=1)
296
+ # codes is [B, K, T], with T frames, K nb of codebooks.
297
+ return codes
298
+
299
+ def decode(self, codes: torch.Tensor) -> torch.Tensor:
300
+ """Decode the given codes to the quantized representation."""
301
+ # codes is [B, K, T], with T frames, K nb of codebooks.
302
+ quantized = self.rvq_first.decode(codes[:, : self.n_q_semantic])
303
+ if codes.shape[1] > self.n_q_semantic:
304
+ quantized += self.rvq_rest.decode(codes[:, self.n_q_semantic :])
305
+ return quantized
306
+
307
+ @property
308
+ def total_codebooks(self):
309
+ return self.rvq_first.max_n_q + self.rvq_rest.max_n_q
310
+
311
+ @property
312
+ def num_codebooks(self):
313
+ return self.rvq_first.num_codebooks + self.rvq_rest.num_codebooks
314
+
315
+ @property
316
+ def n_q(self):
317
+ return self.rvq_first.n_q + self.rvq_rest.n_q
318
+
319
+ @property
320
+ def dimension(self):
321
+ return self.rvq_first.dimension
322
+
323
+ @property
324
+ def semantic_quantizer(self) -> ResidualVectorQuantizer:
325
+ """This returns the quantizer that models the first level of the hierarchy (typically semantic)."""
326
+ return self.rvq_first
327
+
328
+ @property
329
+ def acoustic_quantizer(self) -> ResidualVectorQuantizer:
330
+ """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic)."""
331
+ return self.rvq_rest
332
+
333
+ def set_num_codebooks(self, n: int):
334
+ assert n >= self.n_q_semantic and n <= self.total_codebooks
335
+ self.rvq_rest.set_num_codebooks(n - self.n_q_semantic)
336
+
337
+ @property
338
+ def cardinality(self) -> int:
339
+ assert self.rvq_rest.cardinality == self.rvq_first.cardinality
340
+ return self.rvq_first.cardinality
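Below is a minimal usage sketch of the split quantizer defined above. The class name `SplitResidualVectorQuantizer`, the import path, and the `dimension` keyword are assumptions for illustration only; the point is that `encode` stacks the semantic codebooks before the acoustic ones, and `decode` sums the two branches.

```python
import torch

# Assumed class name and import path; `dimension` is an illustrative kwarg
# forwarded to the underlying ResidualVectorQuantizer.
from moshi.quantization import SplitResidualVectorQuantizer

quantizer = SplitResidualVectorQuantizer(n_q=8, n_q_semantic=1, dimension=256)

x = torch.randn(2, 256, 75)                   # [B, C, T]
codes = quantizer.encode(x)                   # [B, K, T], K = 8 codebooks
semantic = codes[:, :quantizer.n_q_semantic]  # leading codebook(s): semantic RVQ
acoustic = codes[:, quantizer.n_q_semantic:]  # remaining codebooks: acoustic RVQ
recon = quantizer.decode(codes)               # semantic + acoustic contributions summed
```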
moshi/utils/__init__.py ADDED
@@ -0,0 +1,10 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+ """Utilities."""
moshi/utils/autocast.py ADDED
@@ -0,0 +1,45 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+ import torch
12
+
13
+
14
+ class TorchAutocast:
15
+ """TorchAutocast utility class.
16
+ Allows you to enable and disable autocast. This is specially useful
17
+ when dealing with different architectures and clusters with different
18
+ levels of support.
19
+
20
+ Args:
21
+ enabled (bool): Whether to enable torch.autocast or not.
22
+ args: Additional args for torch.autocast.
23
+ kwargs: Additional kwargs for torch.autocast
24
+ """
25
+
26
+ def __init__(self, enabled: bool, *args, **kwargs):
27
+ self.autocast = torch.autocast(*args, **kwargs) if enabled else None
28
+
29
+ def __enter__(self):
30
+ if self.autocast is None:
31
+ return
32
+ try:
33
+ self.autocast.__enter__()
34
+ except RuntimeError:
35
+ device = self.autocast.device
36
+ dtype = self.autocast.fast_dtype
37
+ raise RuntimeError(
38
+ f"There was an error autocasting with dtype={dtype} device={device}\n"
39
+ "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
40
+ )
41
+
42
+ def __exit__(self, *args, **kwargs):
43
+ if self.autocast is None:
44
+ return
45
+ self.autocast.__exit__(*args, **kwargs)
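A short usage sketch of `TorchAutocast`; the import path, model, and dtype below are placeholders, not part of this diff.

```python
import torch
from moshi.utils.autocast import TorchAutocast  # assumed import path

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(16, 16).to(device)
x = torch.randn(4, 16, device=device)

# With enabled=False the context manager is a no-op, so mixed precision can be
# toggled per cluster or architecture without changing the call site.
autocast = TorchAutocast(enabled=device == "cuda", device_type="cuda", dtype=torch.bfloat16)
with autocast:
    y = model(x)
```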
moshi/utils/compile.py ADDED
@@ -0,0 +1,284 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ """
6
+ Provides some extra utilities around torch compile, in particular with a way
7
+ to fully deactivate it easily with a context manager.
8
+ Provides a simple activation checkpointing that is compatible with FSDP and torch compile.
9
+ Finally, provides some utilities for CUDA graphing functions.
10
+ """
11
+ from contextlib import contextmanager
12
+ from functools import wraps
13
+ import inspect
14
+ import os
15
+ import typing as tp
16
+
17
+ import torch
18
+ from torch import cuda
19
+
20
+
21
+ _compile_disabled: bool = False
22
+
23
+
24
+ @contextmanager
25
+ def no_compile():
26
+ """Disable torch.compile locally. Now Pytorch 2.4 provides a function to do that."""
27
+ global _compile_disabled
28
+
29
+ prev_disabled = _compile_disabled
30
+ _compile_disabled = True
31
+ try:
32
+ yield
33
+ finally:
34
+ _compile_disabled = prev_disabled
35
+
36
+
37
+ def torch_compile_lazy(fun):
38
+ """torch.compile creates a huge pool of processes, even when not using the function at all,
39
+ e.g. with Dora. This can pollute stderr when doing CTRL+C, so we compile lazily on first use.
40
+ """
41
+ if os.environ.get("NO_TORCH_COMPILE"):
42
+ return fun
43
+ fun_compiled = None
44
+
45
+ @wraps(fun)
46
+ def _wrapped(*args, **kwargs):
47
+ nonlocal fun_compiled
48
+ if _compile_disabled:
49
+ return fun(*args, **kwargs)
50
+ if fun_compiled is None:
51
+ fun_compiled = torch.compile(fun)
52
+ return fun_compiled(*args, **kwargs)
53
+
54
+ return _wrapped
55
+
56
+
57
+ class Checkpoint(torch.autograd.Function):
58
+ @staticmethod
59
+ def forward(ctx, function, *args) -> tp.Any:
60
+ to_save = []
61
+ ctx.others = []
62
+ ctx.function = function
63
+ # Sources will indicate whether the arg in position N is
64
+ # a tensor stored in ctx.save_for_backward, or inside ctx.others.
65
+ ctx.sources = []
66
+ new_args = []
67
+ for arg in args:
68
+ if isinstance(arg, torch.Tensor):
69
+ to_save.append(arg)
70
+ ctx.sources.append("tensor")
71
+ new_args.append(arg.detach())
72
+ else:
73
+ ctx.sources.append("other")
74
+ ctx.others.append(arg)
75
+ new_args.append(arg)
76
+ ctx.save_for_backward(*to_save)
77
+ # During the forward, we just make a pass with no gradient computed.
78
+ with torch.no_grad():
79
+ res = function(*new_args)
80
+ return res
81
+
82
+ @staticmethod
83
+ def backward(ctx, *grads) -> tp.Tuple[tp.Optional[torch.Tensor], ...]:
84
+ pseudo_tensors = []
85
+ with torch.set_grad_enabled(True):
86
+ # We create leaf tensors to collect the output gradients.
87
+ # We call them pseudo_tensors because they are pretending to be the input
88
+ # to `function`, but are in fact detached copies of the saved inputs.
89
+ for tensor in ctx.saved_tensors:
90
+ pseudo_tensor = tensor.detach()
91
+ pseudo_tensor.requires_grad_(True)
92
+ pseudo_tensors.append(pseudo_tensor)
93
+ pseudo_tensors_copy = list(pseudo_tensors)
94
+ args = []
95
+ for source in ctx.sources:
96
+ if source == "other":
97
+ args.append(ctx.others.pop(0))
98
+ else:
99
+ assert source == "tensor"
100
+ args.append(pseudo_tensors_copy.pop(0))
101
+ res = ctx.function(*args)
102
+ # The second forward with grad computation allows us to connect the input leaf tensors
103
+ # inside pseudo_tensors, to the outputs of the function called.
104
+ if not isinstance(res, tuple):
105
+ res = (res,)
106
+ # Now we just ask Torch to compute the derivative of `res` given the gradient coming from above
107
+ # `grads`. The computed gradient will end up into the `pseudo_tensors` grad attributes.
108
+ torch.autograd.backward(res, grads)
109
+ out: tp.List[tp.Optional[torch.Tensor]] = [None]
110
+ for source in ctx.sources:
111
+ # We still need to output `None` values for non tensor parameters.
112
+ if source == "other":
113
+ out.append(None)
114
+ else:
115
+ assert source == "tensor"
116
+ out.append(pseudo_tensors.pop(0).grad)
117
+ return tuple(out)
118
+
119
+
120
+ def simple_checkpoint(module: torch.nn.Module, *args, **kwargs):
121
+ """Custom implementation of checkpointing in PyTorch as the builtin implementation is broken
122
+ when using torch compile. Only supports wrapping a `nn.Module` with a forward with no `*args` or `**kwargs`.
123
+
124
+ https://github.com/pytorch/pytorch/issues/97436.
125
+ Should be resolved in nightlies, but it is quite fun and simple to code it ourselves.
126
+ """
127
+ if hasattr(module, "_fsdp_wrapped_module"):
128
+ module_for_sig = module._fsdp_wrapped_module
129
+ else:
130
+ module_for_sig = module
131
+ sig = inspect.signature(module_for_sig.forward)
132
+ # We first flatten all arguments to use only *args, to make things easier and because
133
+ # torch.autograd.Function has weird support for kwargs.
134
+ bounded = sig.bind(*args, **kwargs)
135
+ new_args = []
136
+ for name, param in sig.parameters.items():
137
+ if param.kind in {
138
+ inspect.Parameter.VAR_POSITIONAL,
139
+ inspect.Parameter.VAR_KEYWORD,
140
+ }:
141
+ raise RuntimeError("simple_checkpoint doesn't support var args.")
142
+ if name not in bounded.arguments:
143
+ break
144
+ new_args.append(bounded.arguments[name])
145
+ return Checkpoint.apply(module, *new_args)
146
+
147
+
148
+ _in_cuda_graph = False
149
+ _disable_cuda_graph = False
150
+
151
+
152
+ def in_cuda_graph() -> bool:
153
+ """Indicate whether we are in a function that is CUDA Graphed (or will be soon)."""
154
+ return _in_cuda_graph
155
+
156
+
157
+ @contextmanager
158
+ def _set_in_cuda_graph():
159
+ global _in_cuda_graph
160
+ assert not _in_cuda_graph
161
+ _in_cuda_graph = True
162
+ try:
163
+ yield
164
+ finally:
165
+ _in_cuda_graph = False
166
+
167
+
168
+ def _is_cuda_graph_enabled() -> bool:
169
+ if _disable_cuda_graph:
170
+ return False
171
+ no_cuda_graph = os.environ.get("NO_CUDA_GRAPH", "")
172
+ if no_cuda_graph.lower() not in {"0", "no", "n", ""}:
173
+ return False
174
+ return True
175
+
176
+
177
+ @contextmanager
178
+ def no_cuda_graph():
179
+ """Deactivate CUDA Graphing for all the calls in this context manager."""
180
+ global _disable_cuda_graph
181
+ old_value = _disable_cuda_graph
182
+ _disable_cuda_graph = True
183
+ try:
184
+ yield
185
+ finally:
186
+ _disable_cuda_graph = old_value
187
+
188
+
189
+ class CUDAGraphed:
190
+ """Allow simple CUDA Graphing of a function.
191
+
192
+ Args:
193
+ func: callable, taking any number of arguments. Its tensor arguments should
194
+ be top level args, not nested in structures (tuples, dicts, etc). Keyword
195
+ arguments are NOT supported for simplicity.
196
+ warmup_steps: how many calls to make normally before CUDA Graphing. In particular, this
197
+ allows torch.compiled functions to get properly compiled.
198
+ disable: if True, just call the func directly, which is useful to quickly deactivate CUDA Graphing on CPU.
199
+ """
200
+
201
+ def __init__(self, func: tp.Callable, warmup_steps: int = 1, disable: bool = False):
202
+ self.func = func
203
+ self.warmup_steps = warmup_steps
204
+ self.disable = disable
205
+ self._graph: cuda.CUDAGraph | None = None
206
+ self._output: tuple | None = None
207
+ self._args: tuple | None = None
208
+
209
+ def reset(self, warmup_steps: int = 0) -> None:
210
+ """Reset the state, meaning the next call we get CUDA Graphed again. Useful if some
211
+ shapes have changed, or external state (e.g. KVCache) has changed."""
212
+ self.warmup_steps = warmup_steps
213
+ self._graph = None
214
+ self._output = None
215
+ self._args = None
216
+
217
+ def __call__(self, *args, **kwargs) -> tp.Any:
218
+ if kwargs:
219
+ raise RuntimeError("Named arguments not supported for now.")
220
+ if self.disable or not _is_cuda_graph_enabled() or in_cuda_graph():
221
+ return self.func(*args, **kwargs)
222
+
223
+ def _clone_tensors(args: tuple) -> tuple:
224
+ out: list = []
225
+ for arg in args:
226
+ if isinstance(arg, torch.Tensor):
227
+ arg = arg.clone()
228
+ out.append(arg)
229
+ return tuple(out)
230
+
231
+ def _match_values_copy_tensors(args: tuple, target_args: tuple) -> None:
232
+ if len(args) != len(target_args):
233
+ raise ValueError(
234
+ f"Expected {len(target_args)}, but got {args} for CUDA Graphed function."
235
+ )
236
+ for idx, (source, target) in enumerate(zip(args, target_args)):
237
+ if isinstance(target, torch.Tensor):
238
+ if not isinstance(source, torch.Tensor):
239
+ raise ValueError(
240
+ f"Argument #{idx} was a tensor, and is no longer (now {source})."
241
+ )
242
+ if source.shape != target.shape:
243
+ raise ValueError(
244
+ f"Argument #{idx} had shape {target.shape}, but got shae {source.shape}"
245
+ )
246
+ target.copy_(source)
247
+ else:
248
+ if isinstance(source, torch.Tensor):
249
+ raise ValueError(
250
+ f"Argument #{idx} was not a tensor {target}, but is now one."
251
+ )
252
+ if source is not target and source != target:
253
+ raise ValueError(
254
+ f"Argument #{idx} changed value from {target} to {source}."
255
+ )
256
+
257
+ with _set_in_cuda_graph():
258
+ # Prevent anyone below us from trying to CUDA Graph things.
259
+ if self._graph is None:
260
+ if self.warmup_steps <= 0:
261
+ self._graph = cuda.CUDAGraph()
262
+ # Making a copy just to ensure those are not used elsewhere.
263
+ self._args = _clone_tensors(args)
264
+ with cuda.graph(self._graph):
265
+ self._output = self.func(*self._args)
266
+ # At this point nothing really happened, so we have to make it run for real.
267
+ self._graph.replay()
268
+ return self._output
269
+ else:
270
+ self.warmup_steps -= 1
271
+ return self.func(*args)
272
+ else:
273
+ assert self._args is not None
274
+ assert self._output is not None
275
+ _match_values_copy_tensors(args, self._args)
276
+ self._graph.replay()
277
+ return self._output
278
+
279
+
280
+ def cuda_graph(func: tp.Callable, warmup_steps: int = 1):
281
+ """Just calls `CUDAGraphed` on the given function."""
282
+ if not _is_cuda_graph_enabled():
283
+ return func
284
+ return CUDAGraphed(func, warmup_steps)
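A hedged sketch of how the helpers above compose; the import path and the toy function are assumptions, not part of this diff.

```python
import torch
from moshi.utils.compile import torch_compile_lazy, no_compile, CUDAGraphed  # assumed path

@torch_compile_lazy
def gelu_proj(x: torch.Tensor, w: torch.Tensor) -> torch.Tensor:
    return torch.nn.functional.gelu(x @ w)

x, w = torch.randn(8, 32), torch.randn(32, 32)
y = gelu_proj(x, w)            # compiled lazily on this first call
with no_compile():
    y_eager = gelu_proj(x, w)  # temporarily falls back to the eager function

if torch.cuda.is_available():
    x, w = x.cuda(), w.cuda()
    graphed = CUDAGraphed(gelu_proj, warmup_steps=2)
    for _ in range(3):
        y = graphed(x, w)      # two warmup calls, then the CUDA graph is recorded and replayed
```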
moshi/utils/sampling.py ADDED
@@ -0,0 +1,126 @@
1
+ # Copyright (c) Kyutai, all rights reserved.
2
+ # This source code is licensed under the license found in the
3
+ # LICENSE file in the root directory of this source tree.
4
+
5
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
6
+ # All rights reserved.
7
+ #
8
+ # This source code is licensed under the license found in the
9
+ # LICENSE file in the root directory of this source tree.
10
+
11
+
12
+ import torch
13
+
14
+
15
+ def multinomial(
16
+ input: torch.Tensor, num_samples: int, replacement=False, *, generator=None
17
+ ):
18
+ """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.
19
+
20
+ Args:
21
+ input (torch.Tensor): The input tensor containing probabilities.
22
+ num_samples (int): Number of samples to draw.
23
+ replacement (bool): Whether to draw with replacement or not.
24
+ Keywords args:
25
+ generator (torch.Generator): A pseudorandom number generator for sampling.
26
+ Returns:
27
+ torch.Tensor: Last dimension contains num_samples indices
28
+ sampled from the multinomial probability distribution
29
+ located in the last dimension of tensor input.
30
+ """
31
+ input_ = input.reshape(-1, input.shape[-1])
32
+ # We should probably be able to remove this once the following PR has landed:
33
+ # https://github.com/pytorch/pytorch/pull/134818/files
34
+ # In the meantime, we specialize the case no-replacement, nsamples=1 so as not
35
+ # to have a synchronization point.
36
+ if replacement or num_samples != 1:
37
+ output_ = torch.multinomial(
38
+ input_,
39
+ num_samples=num_samples,
40
+ replacement=replacement,
41
+ generator=generator,
42
+ )
43
+ else:
44
+ q = torch.empty_like(input_).exponential_(1, generator=generator)
45
+ q = input_ / q
46
+ output_ = q.argmax(dim=-1, keepdim=True)
47
+ output = output_.reshape(*list(input.shape[:-1]), -1)
48
+ return output
49
+
50
+
51
+ def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
52
+ """Sample next token from top K values along the last dimension of the input probs tensor.
53
+
54
+ Args:
55
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
56
+ k (int): The k in “top-k”.
57
+ Returns:
58
+ torch.Tensor: Sampled tokens.
59
+ """
60
+ probs, indices = torch.topk(probs, k, dim=-1)
61
+ next_token = multinomial(probs, num_samples=1)
62
+ next_token = indices.gather(-1, next_token)
63
+ return next_token
64
+
65
+
66
+ def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
67
+ """Sample next token from top P probabilities along the last dimension of the input probs tensor.
68
+
69
+ Args:
70
+ probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
71
+ p (float): The p in “top-p”.
72
+ Returns:
73
+ torch.Tensor: Sampled tokens.
74
+ """
75
+ probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
76
+ probs_sum = torch.cumsum(probs_sort, dim=-1)
77
+ mask = probs_sum - probs_sort > p
78
+ probs_sort *= (~mask).float()
79
+ probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
80
+ next_token = multinomial(probs_sort, num_samples=1)
81
+ next_token = torch.gather(probs_idx, -1, next_token)
82
+ return next_token
83
+
84
+
85
+ def sample_token(
86
+ logits: torch.Tensor,
87
+ use_sampling: bool = False,
88
+ temp: float = 1.0,
89
+ top_k: int = 0,
90
+ top_p: float = 0.0,
91
+ ) -> torch.Tensor:
92
+ """Given logits of shape [*, Card], returns a LongTensor of shape [*]."""
93
+ # Apply softmax for sampling if temp > 0. Otherwise, do greedy sampling to avoid a zero-division error.
94
+ if use_sampling and temp > 0.0:
95
+ probs = torch.softmax(logits / temp, dim=-1)
96
+ if top_p > 0.0:
97
+ next_token = sample_top_p(probs, p=top_p)
98
+ elif top_k > 0:
99
+ next_token = sample_top_k(probs, k=top_k)
100
+ else:
101
+ next_token = multinomial(probs, num_samples=1)
102
+ else:
103
+ next_token = torch.argmax(logits, dim=-1, keepdim=True)
104
+ assert next_token.shape[-1] == 1
105
+ return next_token[..., 0]
106
+
107
+
108
+ if __name__ == "__main__":
109
+ torch.manual_seed(1234)
110
+ device = "cpu"
111
+ if torch.cuda.is_available():
112
+ torch.backends.cuda.matmul.allow_tf32 = False
113
+ torch.backends.cudnn.allow_tf32 = False
114
+ device = "cuda:0"
115
+
116
+ ps = torch.tensor([5.0, 2.0, 12.0, 6.0, 8.0, 1.0, 0.0, 4.0], device=device)
117
+ cnts = torch.zeros(ps.shape, dtype=torch.long, device=device)
118
+ total_samples = 1000
119
+ for _ in range(total_samples):
120
+ vs = multinomial(ps, num_samples=1, replacement=False)
121
+ cnts[vs] += 1
122
+ diff = cnts / cnts.sum() - ps / ps.sum()
123
+ max_diff = diff.abs().max().cpu().item()
124
+ print(ps / ps.sum())
125
+ print(cnts / cnts.sum())
126
+ assert max_diff < 1.5e-2
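A small usage sketch of `sample_token` on dummy logits (the import path and shapes are assumptions):

```python
import torch
from moshi.utils.sampling import sample_token  # assumed import path

logits = torch.randn(2, 4, 1024)  # [*, Card], e.g. batch x codebooks x vocabulary

greedy = sample_token(logits)                                            # argmax, shape [2, 4]
nucleus = sample_token(logits, use_sampling=True, temp=0.8, top_p=0.95)  # top-p sampling
top_k = sample_token(logits, use_sampling=True, temp=1.0, top_k=50)      # top-k sampling
```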
pyproject.toml ADDED
@@ -0,0 +1,39 @@
1
+ [build-system]
2
+ requires = ["setuptools>=61.0"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "moshi-hf"
7
+ version = "0.1.0"
8
+ description = "Moshi HuggingFace inference server"
9
+ requires-python = ">=3.8"
10
+ dependencies = [
11
+ "torch==2.4.1",
12
+ "torchaudio==2.4.1",
13
+ "torchvision==0.19.1",
14
+ "torchdata==0.10.0",
15
+ "transformers==4.46.3",
16
+ "huggingface-hub==0.27.1",
17
+ "safetensors==0.5.1",
18
+ "accelerate>=0.20.0",
19
+ "datasets==3.1.0",
20
+ "requests==2.32.3",
21
+ "urllib3==2.2.3",
22
+ "pyyaml==6.0.2",
23
+ "einops>=0.6.1",
24
+ "ulid-py==1.1.0",
25
+ "tqdm>=4.65.0",
26
+ "sentencepiece==0.2.0",
27
+ "jiwer==3.0.5",
28
+ "numpy==1.24.4",
29
+ "pandas==2.0.3",
30
+ "scikit-learn==1.5.2",
31
+ "pydub==0.25.1",
32
+ "librosa==0.10.2.post1",
33
+ "pyannote.audio==3.1.1",
34
+ "pesq==0.0.4",
35
+ "torchmetrics==1.6.0",
36
+ "uvicorn==0.25.0",
37
+ "fastapi==0.104.1",
38
+ "pydantic==2.5.2"
39
+ ]
requirements.txt ADDED
@@ -0,0 +1,43 @@
1
+ # PyTorch ecosystem with fixed versions
2
+ torch==2.4.1
3
+ torchaudio==2.4.1
4
+ torchvision==0.19.1
5
+ torchdata==0.10.0
6
+
7
+ # HuggingFace ecosystem
8
+ transformers==4.46.3
9
+ huggingface-hub==0.27.1
10
+ safetensors==0.5.1
11
+ accelerate>=0.20.0
12
+ datasets==3.1.0
13
+
14
+ # HTTP and utils
15
+ requests==2.32.3
16
+ urllib3==2.2.3
17
+ pyyaml==6.0.2
18
+ einops>=0.6.1
19
+ ulid-py==1.1.0
20
+ tqdm>=4.65.0
21
+
22
+ # NLP and text processing
23
+ sentencepiece==0.2.0
24
+ jiwer==3.0.5
25
+
26
+ # Data processing
27
+ numpy==1.24.4
28
+ pandas==2.0.3
29
+ scikit-learn==1.5.2
30
+
31
+ # Audio processing
32
+ pydub==0.25.1
33
+ librosa==0.10.2.post1
34
+ pyannote.audio==3.1.1
35
+ pesq==0.0.4
36
+
37
+ # Metrics and evaluation
38
+ torchmetrics==1.6.0
39
+
40
+ # Server
41
+ uvicorn==0.25.0
42
+ fastapi==0.104.1
43
+ pydantic==2.5.2
server.py ADDED
@@ -0,0 +1,121 @@
1
+ from fastapi import FastAPI, HTTPException
2
+ import numpy as np
3
+ import torch
4
+ from pydantic import BaseModel
5
+ import base64
6
+ import io
7
+ import os
8
+ import logging
9
+ from pathlib import Path
10
+ from inference import InferenceRecipe
11
+ from fastapi.middleware.cors import CORSMiddleware
12
+
13
+ logging.basicConfig(level=logging.INFO)
14
+ logger = logging.getLogger(__name__)
15
+
16
+ app = FastAPI()
17
+
18
+ # Add CORS middleware
19
+ app.add_middleware(
20
+ CORSMiddleware,
21
+ allow_origins=["*"],
22
+ allow_credentials=True,
23
+ allow_methods=["*"],
24
+ allow_headers=["*"],
25
+ )
26
+
27
+ class AudioRequest(BaseModel):
28
+ audio_data: str
29
+ sample_rate: int
30
+
31
+ class AudioResponse(BaseModel):
32
+ audio_data: str
33
+ text: str = ""
34
+
35
+ # Model initialization status
36
+ INITIALIZATION_STATUS = {
37
+ "model_loaded": False,
38
+ "error": None
39
+ }
40
+
41
+ # Global model instance
42
+ model = None
43
+
44
+ def initialize_model():
45
+ """Initialize the model from mounted directory"""
46
+ global model, INITIALIZATION_STATUS
47
+ try:
48
+ device = "cuda" if torch.cuda.is_available() else "cpu"
49
+ logger.info(f"Initializing model on device: {device}")
50
+
51
+ model_path = os.getenv("MODEL_PATH", "/app/models")
52
+ if not os.path.exists(model_path):
53
+ raise RuntimeError(f"Model path {model_path} does not exist")
54
+
55
+ model = InferenceRecipe(model_path, device=device)
56
+ INITIALIZATION_STATUS["model_loaded"] = True
57
+ logger.info("Model initialized successfully")
58
+ return True
59
+ except Exception as e:
60
+ INITIALIZATION_STATUS["error"] = str(e)
61
+ logger.error(f"Failed to initialize model: {e}")
62
+ return False
63
+
64
+ @app.on_event("startup")
65
+ async def startup_event():
66
+ """Initialize model on startup"""
67
+ initialize_model()
68
+
69
+ @app.get("/api/v1/health")
70
+ def health_check():
71
+ """Health check endpoint"""
72
+ status = {
73
+ "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
74
+ "gpu_available": torch.cuda.is_available(),
75
+ "initialization_status": INITIALIZATION_STATUS
76
+ }
77
+
78
+ if model is not None:
79
+ status.update({
80
+ "device": str(model.device),
81
+ "model_path": str(model.model_path),
82
+ "mimi_loaded": model.mimi is not None,
83
+ "tokenizer_loaded": model.text_tokenizer is not None,
84
+ "lm_loaded": model.lm_gen is not None
85
+ })
86
+
87
+ return status
88
+
89
+ @app.post("/api/v1/inference")
90
+ async def inference(request: AudioRequest) -> AudioResponse:
91
+ """Run inference on audio input"""
92
+ if not INITIALIZATION_STATUS["model_loaded"]:
93
+ raise HTTPException(
94
+ status_code=503,
95
+ detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
96
+ )
97
+
98
+ try:
99
+ # Decode audio from base64
100
+ audio_bytes = base64.b64decode(request.audio_data)
101
+ audio_array = np.load(io.BytesIO(audio_bytes))
102
+
103
+ # Run inference
104
+ result = model.inference(audio_array, request.sample_rate)
105
+
106
+ # Encode output audio
107
+ buffer = io.BytesIO()
108
+ np.save(buffer, result['audio'])
109
+ audio_b64 = base64.b64encode(buffer.getvalue()).decode()
110
+
111
+ return AudioResponse(
112
+ audio_data=audio_b64,
113
+ text=result.get("text", "")
114
+ )
115
+ except Exception as e:
116
+ logger.error(f"Inference failed: {str(e)}")
117
+ raise HTTPException(status_code=500, detail=str(e))
118
+
119
+ if __name__ == "__main__":
120
+ import uvicorn
121
+ uvicorn.run(app, host="0.0.0.0", port=8000)
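A minimal client sketch for the API above. The host, port, waveform shape, and 24 kHz sample rate are assumptions; the encoding mirrors what the endpoint expects (a base64-encoded buffer produced by `np.save`).

```python
import base64
import io

import numpy as np
import requests

audio = np.zeros((1, 1, 24000), dtype=np.float32)  # placeholder 1-second waveform

buffer = io.BytesIO()
np.save(buffer, audio)  # the server decodes with np.load
payload = {
    "audio_data": base64.b64encode(buffer.getvalue()).decode(),
    "sample_rate": 24000,
}

resp = requests.post("http://localhost:8000/api/v1/inference", json=payload, timeout=300)
resp.raise_for_status()
result = resp.json()
out_audio = np.load(io.BytesIO(base64.b64decode(result["audio_data"])))
print(result.get("text", ""), out_audio.shape)
```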
setup.py ADDED
@@ -0,0 +1,78 @@
1
+ # The MIT License (MIT)
2
+ # Copyright © 2023 Yuma Rao
3
+ # Copyright © 2024 Omega Labs, Inc.
4
+
5
+ # Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
6
+ # documentation files (the “Software”), to deal in the Software without restriction, including without limitation
7
+ # the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
8
+ # and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
9
+
10
+ # The above copyright notice and this permission notice shall be included in all copies or substantial portions of
11
+ # the Software.
12
+
13
+ # THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
14
+ # THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
15
+ # THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
16
+ # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
17
+ # DEALINGS IN THE SOFTWARE.
18
+
19
+ import re
20
+ import os
21
+ import codecs
22
+ from os import path
23
+ from io import open
24
+ from setuptools import setup, find_packages
25
+
26
+
27
+ def read_requirements(path):
28
+ with open(path, "r") as f:
29
+ requirements = f.read().splitlines()
30
+ return requirements
31
+
32
+
33
+ requirements = read_requirements("requirements.txt")
34
+ here = path.abspath(path.dirname(__file__))
35
+
36
+ with open(path.join(here, "README.md"), encoding="utf-8") as f:
37
+ long_description = f.read()
38
+
39
+ # loading the version string from template/__init__.py
40
+ with codecs.open(
41
+ os.path.join(here, "template/__init__.py"), encoding="utf-8"
42
+ ) as init_file:
43
+ version_match = re.search(
44
+ r"^__version__ = ['\"]([^'\"]*)['\"]", init_file.read(), re.M
45
+ )
46
+ version_string = version_match.group(1)
47
+
48
+ setup(
49
+ name="omegalabs-anytoany-bittensor",
50
+ version=version_string,
51
+ description="omegalabs-anytoany-bittensor",
52
+ long_description=long_description,
53
+ long_description_content_type="text/markdown",
54
+ url="https://github.com/omegalabsinc/omegalabs-anytoany-bittensor",
55
+ author="OMEGA Labs, Inc.",
56
+ packages=find_packages(),
57
+ include_package_data=True,
58
+ author_email="[email protected]",
59
+ license="MIT",
60
+ python_requires=">=3.8",
61
+ install_requires=requirements,
62
+ classifiers=[
63
+ "Development Status :: 3 - Alpha",
64
+ "Intended Audience :: Developers",
65
+ "Topic :: Software Development :: Build Tools",
66
+ "License :: OSI Approved :: MIT License",
67
+ "Programming Language :: Python :: 3 :: Only",
68
+ "Programming Language :: Python :: 3.8",
69
+ "Programming Language :: Python :: 3.9",
70
+ "Programming Language :: Python :: 3.10",
71
+ "Topic :: Scientific/Engineering",
72
+ "Topic :: Scientific/Engineering :: Mathematics",
73
+ "Topic :: Scientific/Engineering :: Artificial Intelligence",
74
+ "Topic :: Software Development",
75
+ "Topic :: Software Development :: Libraries",
76
+ "Topic :: Software Development :: Libraries :: Python Modules",
77
+ ],
78
+ )
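For reference, the version lookup above expects a `template/__init__.py` roughly like the following; the version number is illustrative.

```python
# template/__init__.py -- only the __version__ assignment matters to setup.py
__version__ = "0.0.1"
```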