Upload folder using huggingface_hub
- Dockerfile +30 -0
- README.md +32 -0
- chunk_silence.py +53 -0
- config.yaml +6 -0
- hotkey.txt +1 -0
- inference.py +244 -0
- mimi_tokenizer.py +48 -0
- models/model.safetensors +3 -0
- models/tokenizer-e351c8d8-checkpoint125.safetensors +3 -0
- models/tokenizer_spm_32k_3.model +3 -0
- moshi/chunk_silence.py +53 -0
- moshi/models/__init__.py +14 -0
- moshi/models/compression.py +474 -0
- moshi/models/lm.py +487 -0
- moshi/models/loaders.py +159 -0
- moshi/modules/__init__.py +23 -0
- moshi/modules/conv.py +329 -0
- moshi/modules/gating.py +82 -0
- moshi/modules/resample.py +119 -0
- moshi/modules/rope.py +90 -0
- moshi/modules/seanet.py +395 -0
- moshi/modules/streaming.py +363 -0
- moshi/modules/transformer.py +750 -0
- moshi/quantization/__init__.py +13 -0
- moshi/quantization/base.py +170 -0
- moshi/quantization/core_vq.py +384 -0
- moshi/quantization/vq.py +340 -0
- moshi/utils/__init__.py +10 -0
- moshi/utils/autocast.py +45 -0
- moshi/utils/compile.py +284 -0
- moshi/utils/sampling.py +126 -0
- pyproject.toml +39 -0
- requirements.txt +43 -0
- server.py +121 -0
- setup.py +78 -0
Dockerfile
ADDED
@@ -0,0 +1,30 @@
FROM pytorch/pytorch:2.1.2-cuda12.1-cudnn8-runtime

WORKDIR /app

# Install system dependencies
RUN apt-get update && \
    DEBIAN_FRONTEND=noninteractive apt-get install -y --no-install-recommends \
    build-essential \
    python3-dev \
    git \
    ffmpeg \
    libsndfile1 \
    && rm -rf /var/lib/apt/lists/*

# Pre-install critical dependencies
RUN pip install --no-cache-dir --upgrade pip setuptools wheel

# Copy entire directory
COPY . .

# Install package in editable mode
RUN pip install -e .

# Environment variables
ENV MODEL_PATH=/app/models \
    PYTHONUNBUFFERED=1

EXPOSE 8000

CMD ["python", "server.py"]
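A typical way to exercise this image (the tag and host port below are illustrative, not part of the upload) would be `docker build -t omega-a2a .` followed by `docker run --gpus all -p 8000:8000 omega-a2a`: the `CMD` launches `server.py` on the exposed port, and the base image assumes a CUDA 12.1-capable runtime.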
README.md
ADDED
@@ -0,0 +1,32 @@
---
license: mit
tags:
- any-to-any
- omega
- omegalabs
- bittensor
- agi
---

This is an Any-to-Any model checkpoint for the OMEGA Labs x Bittensor Any-to-Any subnet.

Check out the [git repo](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor) and find OMEGA on X: [@omegalabsai](https://x.com/omegalabsai).

## License

This project is licensed under the MIT License - see the LICENSE file for details.

## Links

- GitHub Repository: [omegalabs-anytoany-bittensor](https://github.com/omegalabsinc/omegalabs-anytoany-bittensor)
- OMEGA Labs on X: [@omegalabsai](https://x.com/omegalabsai)

## Contributing

Contributions are welcome! Please feel free to submit a Pull Request.

## Support

For support and questions, please:
1. Open an issue on GitHub
2. Follow OMEGA Labs on X [@omegalabsai](https://x.com/omegalabsai)
chunk_silence.py
ADDED
@@ -0,0 +1,53 @@
import os
import glob

import numpy as np
import torch
import torchaudio
from pydub import AudioSegment
from pydub.silence import detect_silence
from tqdm import tqdm

def detect_and_trim_silence(audio_array, frame_rate, min_silence_duration=1000, silence_threshold=-40):
    # Convert the waveform tensor to 16-bit PCM bytes for pydub
    audio_np = audio_array.detach().cpu().numpy()
    audio_bytes = (audio_np * 32767).astype(np.int16).tobytes()

    # Create AudioSegment from raw audio bytes
    audio = AudioSegment(
        data=audio_bytes,
        sample_width=2,  # 16-bit audio = 2 bytes
        frame_rate=frame_rate,
        channels=1  # Mono audio
    )

    # Detect silence (intervals are returned in milliseconds)
    silence_intervals = detect_silence(
        audio,
        min_silence_len=min_silence_duration,
        silence_thresh=silence_threshold
    )

    # Convert milliseconds to seconds
    silence_intervals_seconds = [(start / 1000, end / 1000) for start, end in silence_intervals]

    # Slice the waveform from the end of the first silent interval to the end
    first_silence_end = silence_intervals_seconds[0][1]
    trimmed_audio = audio_np[..., int(first_silence_end * frame_rate):]

    return trimmed_audio

def process_all_audio_files(root_dir, output_dir):
    # Use glob to find all .wav files in the directory and its subdirectories
    wav_files = glob.glob(os.path.join(root_dir, '**', '*.wav'), recursive=True)
    for input_path in tqdm(wav_files, desc="Processing audio files"):
        relative_path = os.path.relpath(input_path, root_dir)
        target_dir = os.path.join(output_dir, os.path.dirname(relative_path))
        os.makedirs(target_dir, exist_ok=True)
        output_path = os.path.join(target_dir, os.path.basename(input_path))
        # Load the audio, down-mix to mono, trim the leading silence, and save
        waveform, sample_rate = torchaudio.load(input_path)
        waveform = waveform.mean(dim=0, keepdim=True)
        trimmed = detect_and_trim_silence(waveform, sample_rate)
        torchaudio.save(output_path, torch.from_numpy(trimmed), sample_rate)

# Use the function to process all audio files
if __name__ == "__main__":
    root_directory = "/workspace/tezuesh/omega-v2v/.predictions_warmup/moshi/audio/"
    output_directory = "/workspace/tezuesh/omega-v2v/.predictions_warmup/moshi/trimmed/"
    process_all_audio_files(root_directory, output_directory)
config.yaml
ADDED
@@ -0,0 +1,6 @@
model:
  name: moshi
  version: 1.0
  description: "Moshi Pretrained Model"
  author: "Tezuesh"
  license: "MIT"
hotkey.txt
ADDED
@@ -0,0 +1 @@
5FeqmebkCWfepQPgSkrEHRwtpUmHGASF4BNERZDs9pvKFtcD
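This file holds a single SS58-encoded Bittensor hotkey address, presumably used to associate this checkpoint with the miner that submitted it to the subnet.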
inference.py
ADDED
@@ -0,0 +1,244 @@
import torch
import numpy as np
import torchaudio
import sentencepiece
import logging
from pathlib import Path
from moshi.models import loaders, LMGen

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

class InferenceRecipe:
    """Handles model inference for the Any-to-Any model."""

    def __init__(self, model_path: str, device: str='cuda'):
        """Initialize the model.

        Args:
            model_path (str): Path to model directory with pre-downloaded files
            device (str): Device to run on ('cuda' or 'cpu')
        """
        self.device = torch.device(device)
        self.model_path = Path(model_path)

        # Set sample rate and frame rate
        self.sample_rate = 24000  # Based on model config in loaders.py
        self.frame_rate = 12.5  # Based on model config in loaders.py

        # Initialize all model components
        logger.info(f"Initializing models from {model_path}")
        self.mimi, self.text_tokenizer, self.lm_gen = self._initialize_models()
        logger.info("Model initialization complete")

    def _initialize_models(self):
        """Initialize all required model components."""
        print("Initializing models...")

        try:
            # Load MIMI model for encoding/decoding
            mimi_path = self.model_path / loaders.MIMI_NAME
            if not mimi_path.exists():
                raise RuntimeError(f"MIMI model not found at {mimi_path}")
            logger.info(f"Loading MIMI model from {mimi_path}")
            mimi = loaders.get_mimi(str(mimi_path), device=self.device)
            mimi.set_num_codebooks(8)

            # Load text tokenizer
            tokenizer_path = self.model_path / loaders.TEXT_TOKENIZER_NAME
            if not tokenizer_path.exists():
                raise RuntimeError(f"Text tokenizer not found at {tokenizer_path}")
            logger.info(f"Loading text tokenizer from {tokenizer_path}")
            text_tokenizer = sentencepiece.SentencePieceProcessor(str(tokenizer_path))

            # Load language model
            moshi_path = self.model_path / loaders.MOSHI_NAME
            if not moshi_path.exists():
                raise RuntimeError(f"Language model not found at {moshi_path}")
            logger.info(f"Loading language model from {moshi_path}")
            moshi = loaders.get_moshi_lm(str(moshi_path), device=self.device)
            lm_gen = LMGen(moshi, temp=0.8, temp_text=0.7)

            return mimi, text_tokenizer, lm_gen

        except Exception as e:
            logger.error(f"Model initialization failed: {str(e)}")
            raise

    def _load_audio(self, audio_array: np.ndarray, sample_rate: int):
        """Load and preprocess audio."""
        try:
            # Convert to tensor
            wav = torch.from_numpy(audio_array).float().unsqueeze(0).unsqueeze(0)

            # Resample if needed
            if sample_rate != self.sample_rate:
                logger.info(f"Resampling from {sample_rate} to {self.sample_rate}")
                wav = torchaudio.transforms.Resample(
                    orig_freq=sample_rate,
                    new_freq=self.sample_rate
                )(wav)

            # Ensure frame alignment
            frame_size = int(self.sample_rate / self.frame_rate)
            orig_length = wav.shape[-1]
            wav = wav[:, :, :(wav.shape[-1] // frame_size) * frame_size]
            if wav.shape[-1] != orig_length:
                logger.info(f"Trimmed audio from {orig_length} to {wav.shape[-1]} samples for frame alignment")

            return wav

        except Exception as e:
            logger.error(f"Audio loading failed: {str(e)}")
            raise

    def _pad_codes(self, all_codes, time_seconds=30):
        """Pad codes to minimum length if needed."""
        try:
            min_frames = int(time_seconds * self.frame_rate)
            frame_size = int(self.sample_rate / self.frame_rate)

            if len(all_codes) < min_frames:
                frames_to_add = min_frames - len(all_codes)
                logger.info(f"Padding {frames_to_add} frames to reach minimum length")
                with torch.no_grad(), self.mimi.streaming(batch_size=1):
                    chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
                    for _ in range(frames_to_add):
                        additional_code = self.mimi.encode(chunk)
                        all_codes.append(additional_code)

            return all_codes

        except Exception as e:
            logger.error(f"Code padding failed: {str(e)}")
            raise

    def _encode_audio(self, wav: torch.Tensor):
        """Convert audio to codes."""
        try:
            frame_size = int(self.sample_rate / self.frame_rate)
            all_codes = []

            with torch.no_grad(), self.mimi.streaming(batch_size=1):
                for offset in range(0, wav.shape[-1], frame_size):
                    frame = wav[:, :, offset: offset + frame_size]
                    codes = self.mimi.encode(frame.to(self.device))
                    assert codes.shape[-1] == 1, f"Expected code shape (*, *, 1), got {codes.shape}"
                    all_codes.append(codes)

            logger.info(f"Encoded {len(all_codes)} frames")
            return all_codes

        except Exception as e:
            logger.error(f"Audio encoding failed: {str(e)}")
            raise

    def _warmup(self):
        """Run a warmup pass."""
        try:
            frame_size = int(self.sample_rate / self.frame_rate)
            chunk = torch.zeros(1, 1, frame_size, dtype=torch.float32, device=self.device)
            codes = self.mimi.encode(chunk)

            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                tokens = self.lm_gen.step(codes[:, :, 0:1])
                if tokens is not None:
                    _ = self.mimi.decode(tokens[:, 1:])

            torch.cuda.synchronize()
            logger.info("Warmup pass completed")

        except Exception as e:
            logger.error(f"Warmup failed: {str(e)}")
            raise

    def _generate(self, all_codes):
        """Generate audio and text from codes."""
        try:
            out_wav_chunks = []
            text_output = []

            with torch.no_grad(), self.lm_gen.streaming(1), self.mimi.streaming(1):
                for i, code in enumerate(all_codes):
                    assert code.shape == (1, 8, 1), f"Expected code shape (1, 8, 1), got {code.shape}"
                    tokens_out = self.lm_gen.step(code.to(self.device))

                    if tokens_out is not None:
                        # Generate audio
                        wav_chunk = self.mimi.decode(tokens_out[:, 1:])
                        out_wav_chunks.append(wav_chunk)

                        # Generate text if available
                        text_token = tokens_out[0, 0, 0].item()
                        if text_token not in (0, 3):
                            _text = self.text_tokenizer.id_to_piece(text_token)
                            _text = _text.replace("▁", " ")
                            text_output.append(_text)

                    if (i + 1) % 100 == 0:
                        logger.info(f"Processed {i + 1}/{len(all_codes)} frames")

            wav = torch.cat(out_wav_chunks, dim=-1)
            text = ''.join(text_output)

            logger.info(f"Generated {wav.shape[-1]} samples of audio and {len(text)} characters of text")
            return wav, text

        except Exception as e:
            logger.error(f"Generation failed: {str(e)}")
            raise

    def inference(self, audio_array: np.ndarray, sample_rate: int) -> dict:
        """Run inference on input audio.

        Args:
            audio_array (np.ndarray): Input audio as numpy array
            sample_rate (int): Sample rate of input audio

        Returns:
            dict: Contains generated audio array and optional transcribed text
        """
        try:
            logger.info(f"Starting inference on {len(audio_array)} samples at {sample_rate}Hz")

            # Load and preprocess audio
            wav = self._load_audio(audio_array, sample_rate)
            wav = wav.to(self.device)

            # Convert to codes
            all_codes = self._encode_audio(wav)
            all_codes = self._pad_codes(all_codes)

            # Warmup pass
            self._warmup()

            # Generate output
            out_wav, text = self._generate(all_codes)

            # Convert output to numpy
            output = out_wav.cpu().numpy().squeeze()

            logger.info("Inference completed successfully")
            return {
                "audio": output,
                "text": text
            }

        except Exception as e:
            logger.error(f"Inference failed: {str(e)}")
            raise

if __name__ == "__main__":
    # Example usage
    import librosa

    # Initialize model
    model = InferenceRecipe("/path/to/models", device="cuda")

    # Load test audio
    audio, sr = librosa.load("test.wav", sr=None)

    # Run inference
    result = model.inference(audio, sr)
    print(f"Generated {len(result['audio'])} samples of audio")
    print(f"Generated text: {result['text']}")
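The frame bookkeeping in `InferenceRecipe` is worth making concrete. The following is a minimal sanity check of the arithmetic, derived directly from the constants and slicing above (no model required):

```python
# Frame arithmetic used throughout InferenceRecipe.
sample_rate = 24000
frame_rate = 12.5
frame_size = int(sample_rate / frame_rate)        # 1920 samples per Mimi frame
min_frames = int(30 * frame_rate)                 # _pad_codes pads up to 375 frames (30 s)

# _load_audio trims the input to a whole number of frames:
n_samples = 100_000
aligned = (n_samples // frame_size) * frame_size  # 99840 samples -> 52 full frames
print(frame_size, min_frames, aligned)
```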
mimi_tokenizer.py
ADDED
@@ -0,0 +1,48 @@
from moshi import models
loaders = models.loaders
from huggingface_hub import hf_hub_download
import torch
from pydub import AudioSegment
import numpy as np

MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'


device = "cuda" if torch.cuda.is_available() else "cpu"
mimi_weight = hf_hub_download(loaders.DEFAULT_REPO, loaders.MIMI_NAME)
mimi = loaders.get_mimi(mimi_weight, device=device)

def encode_audio(mimi, wav, device):
    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    all_codes = []
    with torch.no_grad(), mimi.streaming(batch_size=1):
        for offset in range(0, wav.shape[-1], frame_size):
            frame = wav[:, :, offset: offset + frame_size]
            codes = mimi.encode(frame.to(device))
            assert codes.shape[-1] == 1, codes.shape
            all_codes.append(codes)

    return all_codes


def load_audio(wav_path, mimi):
    audio = AudioSegment.from_wav(wav_path)
    samples = np.array(audio.get_array_of_samples())
    samples = samples.astype(np.float32) / (2**15 if audio.sample_width == 2 else 2**31)
    wav = torch.from_numpy(samples).float().unsqueeze(0).unsqueeze(0)

    if audio.frame_rate != mimi.sample_rate:
        wav = torch.nn.functional.interpolate(wav, scale_factor=mimi.sample_rate/audio.frame_rate, mode='linear', align_corners=False)

    frame_size = int(mimi.sample_rate / mimi.frame_rate)
    wav = wav[:, :, :(wav.shape[-1] // frame_size) * frame_size]

    return wav
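For reference, a minimal sketch of how the two helpers compose, if appended to mimi_tokenizer.py (the `sample.wav` path is illustrative):

```python
# Hypothetical usage of the helpers above: tokenize a local wav file.
wav = load_audio("sample.wav", mimi)        # [1, 1, T], trimmed to whole frames
codes = encode_audio(mimi, wav, device)     # list of [1, K, 1] code tensors
codes = torch.cat(codes, dim=-1)            # [1, K, num_frames]
print(codes.shape)
```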
models/model.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5b835c664f3830bf453808cbca9bfbcc9de332c328cc01cbffdfbaba2a8838a7
size 15375500136
models/tokenizer-e351c8d8-checkpoint125.safetensors
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:09b782f0629851a271227fb9d36db65c041790365f11bbe5d3d59369cf863f50
size 384644900
models/tokenizer_spm_32k_3.model
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:78d4336533ddc26f9acf7250d7fb83492152196c6ea4212c841df76933f18d2d
size 552778
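These three entries are Git LFS pointer stubs rather than the weights themselves: each records the LFS spec version, the SHA-256 of the real file, and its size in bytes (about 15.4 GB for the language model, 385 MB for the Mimi tokenizer checkpoint, and 553 KB for the SentencePiece model). The actual binaries live in LFS storage.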
moshi/chunk_silence.py
ADDED
@@ -0,0 +1,53 @@
import os
import glob

import numpy as np
import torch
import torchaudio
from pydub import AudioSegment
from pydub.silence import detect_silence
from tqdm import tqdm

def detect_and_trim_silence(audio_array, frame_rate, min_silence_duration=1000, silence_threshold=-40):
    # Convert the waveform tensor to 16-bit PCM bytes for pydub
    audio_np = audio_array.detach().cpu().numpy()
    audio_bytes = (audio_np * 32767).astype(np.int16).tobytes()

    # Create AudioSegment from raw audio bytes
    audio = AudioSegment(
        data=audio_bytes,
        sample_width=2,  # 16-bit audio = 2 bytes
        frame_rate=frame_rate,
        channels=1  # Mono audio
    )

    # Detect silence (intervals are returned in milliseconds)
    silence_intervals = detect_silence(
        audio,
        min_silence_len=min_silence_duration,
        silence_thresh=silence_threshold
    )

    # Convert milliseconds to seconds
    silence_intervals_seconds = [(start / 1000, end / 1000) for start, end in silence_intervals]

    # Slice the waveform from the end of the first silent interval to the end
    first_silence_end = silence_intervals_seconds[0][1]
    trimmed_audio = audio_np[..., int(first_silence_end * frame_rate):]

    return trimmed_audio

def process_all_audio_files(root_dir, output_dir):
    # Use glob to find all .wav files in the directory and its subdirectories
    wav_files = glob.glob(os.path.join(root_dir, '**', '*.wav'), recursive=True)
    for input_path in tqdm(wav_files, desc="Processing audio files"):
        relative_path = os.path.relpath(input_path, root_dir)
        target_dir = os.path.join(output_dir, os.path.dirname(relative_path))
        os.makedirs(target_dir, exist_ok=True)
        output_path = os.path.join(target_dir, os.path.basename(input_path))
        # Load the audio, down-mix to mono, trim the leading silence, and save
        waveform, sample_rate = torchaudio.load(input_path)
        waveform = waveform.mean(dim=0, keepdim=True)
        trimmed = detect_and_trim_silence(waveform, sample_rate)
        torchaudio.save(output_path, torch.from_numpy(trimmed), sample_rate)

# Use the function to process all audio files
if __name__ == "__main__":
    root_directory = "/workspace/tezuesh/omega-v2v/.predictions_warmup/moshi/audio/"
    output_directory = "/workspace/tezuesh/omega-v2v/.predictions_warmup/moshi/trimmed/"
    process_all_audio_files(root_directory, output_directory)
moshi/models/__init__.py
ADDED
@@ -0,0 +1,14 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""
Models for the compression model Moshi,
"""

# flake8: noqa
from moshi.models.compression import (
    CompressionModel,
    MimiModel,
)
from moshi.models.lm import LMModel, LMGen
from moshi.models.loaders import get_mimi, get_moshi_lm
moshi/models/compression.py
ADDED
@@ -0,0 +1,474 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Part of this file is adapted from encodec.py in https://github.com/facebookresearch/audiocraft
# released under the following license.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Compression models or wrapper around existing models. In particular, provides the implementation
for Mimi. Also defines the main interface that a model must follow to be usable as an audio tokenizer.
"""

from abc import abstractmethod
from contextlib import nullcontext
from dataclasses import dataclass
import logging
import typing as tp

import torch
from torch import nn


from moshi.quantization import (
    QuantizedResult,
    BaseQuantizer,
    SplitResidualVectorQuantizer,
    ResidualVectorQuantizer,
)
from moshi.modules.resample import ConvDownsample1d, ConvTrUpsample1d
from moshi.modules.streaming import StreamingModule, State
from moshi.utils.compile import no_compile, CUDAGraphed


logger = logging.getLogger()


class CompressionModel(StreamingModule[State]):
    """Base API for all compression models that aim at being used as audio tokenizers
    with a language model.
    """

    @abstractmethod
    def forward(self, x: torch.Tensor) -> QuantizedResult: ...

    @abstractmethod
    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """See `MimiModel.encode`."""
        ...

    @abstractmethod
    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """See `MimiModel.decode`."""
        ...

    @abstractmethod
    def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode from the discrete codes to continuous latent space."""
        ...

    @property
    @abstractmethod
    def channels(self) -> int: ...

    @property
    @abstractmethod
    def frame_rate(self) -> float: ...

    @property
    @abstractmethod
    def sample_rate(self) -> int: ...

    @property
    @abstractmethod
    def cardinality(self) -> int: ...

    @property
    @abstractmethod
    def num_codebooks(self) -> int: ...

    @property
    @abstractmethod
    def total_codebooks(self) -> int: ...

    @abstractmethod
    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        ...


@dataclass
class _MimiState:
    graphed_tr_enc: CUDAGraphed | None
    graphed_tr_dec: CUDAGraphed | None

    def reset(self):
        pass


class MimiModel(CompressionModel[_MimiState]):
    """Mimi model operating on the raw waveform.

    Args:
        encoder (nn.Module): Encoder network.
        decoder (nn.Module): Decoder network.
        quantizer (qt.BaseQuantizer): Quantizer network.
        frame_rate (float): Final frame rate of the quantized representation.
        encoder_frame_rate (float): frame rate of the encoder model. Note that if `frame_rate != encoder_frame_rate`,
            the latent will be resampled linearly to match the desired `frame_rate` before and after quantization.
        sample_rate (int): Audio sample rate.
        channels (int): Number of audio channels.
        causal (bool): Whether to use a causal version of the model.
        encoder_transformer (nn.Module or None): optional transformer for the encoder.
        decoder_transformer (nn.Module or None): optional transformer for the decoder.
        resample_method (str): method to use for resampling the latent space before the quantizer.
        upsample_channel_wise_bug (bool): controls whether the upsampling is channel wise.
            Defaults to true to reproduce bug in original implementation.
        freeze_encoder: whether to freeze the encoder weights.
        freeze_quantizer: whether to freeze the quantizer weights.
        freeze_quantizer_level: If positive, freeze the quantizer up to this level.
        torch_compile_encoder_decoder (bool): if True, uses torch.compile on the encoder / decoder.
            Deactivated by default for training as this is incompatible at the moment with weight norm.
            See https://github.com/pytorch/pytorch/issues/121902
            Also this seems to work well with 2.2.0, but completely fail with 2.4.0.
    """

    def __init__(
        self,
        encoder: nn.Module,
        decoder: nn.Module,
        quantizer: BaseQuantizer,
        frame_rate: float,
        encoder_frame_rate: float,
        sample_rate: int,
        channels: int,
        causal: bool = False,
        encoder_transformer: tp.Optional[nn.Module] = None,
        decoder_transformer: tp.Optional[nn.Module] = None,
        resample_method: str = "interpolate",
        upsample_channel_wise_bug: bool = True,
        freeze_encoder: bool = False,
        freeze_quantizer: bool = False,
        freeze_quantizer_level: int = -1,
        torch_compile_encoder_decoder: bool = False,
    ):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.encoder_transformer = encoder_transformer
        self.decoder_transformer = decoder_transformer
        self.quantizer = quantizer
        self._frame_rate = frame_rate
        self._sample_rate = sample_rate
        self._channels = channels
        self.encoder_frame_rate = encoder_frame_rate
        self.torch_compile_encoder_decoder = torch_compile_encoder_decoder

        if freeze_encoder:
            for p in self.encoder.parameters():
                p.requires_grad = False
            if self.encoder_transformer is not None:
                for p in self.encoder_transformer.parameters():
                    p.requires_grad = False
            for name, p in self.quantizer.named_parameters():
                if name.endswith("input_proj.weight"):
                    p.requires_grad = False
        if freeze_quantizer:
            self.quantizer.ema_frozen_(True)
        self.freeze_quantizer = freeze_quantizer
        self.freeze_quantizer_level = (
            freeze_quantizer_level
            if freeze_quantizer_level > 0
            else self.quantizer.num_codebooks
        )

        # We will need the dimension for the resampling. In general the encoder will be a SeanetEncoder
        # which exposes a `dimension` attribute.
        dimension = encoder.dimension
        assert isinstance(
            dimension, int
        ), f"Dimension should be int, got {dimension} of type {type(dimension)}."
        self.dimension = dimension

        assert resample_method in [
            "interpolate",
            "conv",
            "avg_pool",
        ], f"Invalid resample_method {resample_method}"
        self.resample_method = resample_method
        if encoder_frame_rate != frame_rate:
            assert not (
                causal and resample_method == "interpolate"
            ), "Cannot interpolate with causal model."
            if resample_method in ["conv", "avg_pool"]:
                assert (
                    self.encoder_frame_rate > self.frame_rate
                ), "Cannot upsample with conv."
                downsample_stride = self.encoder_frame_rate / self.frame_rate
                assert downsample_stride == int(
                    downsample_stride
                ), f"Only integer strides are supported, got {downsample_stride}"
                learnt = resample_method == "conv"
                self.downsample = ConvDownsample1d(
                    int(downsample_stride),
                    dimension=dimension,
                    learnt=learnt,
                    causal=causal,
                )
                if freeze_encoder:
                    for p in self.downsample.parameters():
                        p.requires_grad = False
                self.upsample = ConvTrUpsample1d(
                    int(downsample_stride),
                    dimension=dimension,
                    learnt=learnt,
                    causal=causal,
                    channel_wise=upsample_channel_wise_bug,
                )

    def _init_streaming_state(self, batch_size: int) -> _MimiState:
        device = next(self.parameters()).device
        disable = device.type != 'cuda'
        graphed_tr_dec = None
        graphed_tr_enc = None
        if self.encoder_transformer is not None:
            graphed_tr_enc = CUDAGraphed(self.encoder_transformer, disable=disable)
        if self.decoder_transformer is not None:
            graphed_tr_dec = CUDAGraphed(self.decoder_transformer, disable=disable)
        return _MimiState(graphed_tr_enc, graphed_tr_dec)

    @property
    def channels(self) -> int:
        return self._channels

    @property
    def frame_rate(self) -> float:
        return self._frame_rate

    @property
    def sample_rate(self) -> int:
        return self._sample_rate

    @property
    def total_codebooks(self):
        """Total number of quantizer codebooks available."""
        return self.quantizer.total_codebooks

    @property
    def num_codebooks(self):
        """Active number of codebooks used by the quantizer."""
        return self.quantizer.num_codebooks

    def set_num_codebooks(self, n: int):
        """Set the active number of codebooks used by the quantizer."""
        self.quantizer.set_num_codebooks(n)

    @property
    def cardinality(self):
        """Cardinality of each codebook."""
        return self.quantizer.cardinality

    def _to_framerate(self, x: torch.Tensor):
        # Convert from the encoder frame rate to the overall framerate.
        _, _, length = x.shape
        frame_rate = self.encoder_frame_rate
        new_frame_rate = self.frame_rate
        if frame_rate == new_frame_rate:
            return x
        if self.resample_method == "interpolate":
            target_length = int(length * new_frame_rate / frame_rate)
            return nn.functional.interpolate(x, size=target_length, mode="linear")
        else:
            return self.downsample(x)

    def _to_encoder_framerate(self, x: torch.Tensor):
        # Convert from overall framerate to the encoder frame rate.
        _, _, length = x.shape
        frame_rate = self.encoder_frame_rate
        new_frame_rate = self.frame_rate
        if frame_rate == new_frame_rate:
            return x
        if self.resample_method == "interpolate":
            target_length = int(length * new_frame_rate / frame_rate)
            return nn.functional.interpolate(x, size=target_length, mode="linear")
        else:
            return self.upsample(x)

    @property
    def _context_for_encoder_decoder(self):
        if self.torch_compile_encoder_decoder:
            return nullcontext()
        else:
            return no_compile()

    def forward(self, x: torch.Tensor) -> QuantizedResult:
        assert x.dim() == 3
        length = x.shape[-1]
        extra_metrics: tp.Dict[str, torch.Tensor] = {}

        if self.freeze_quantizer:
            if isinstance(self.quantizer, SplitResidualVectorQuantizer):
                self.quantizer.rvq_first.eval()
                for i in range(
                    self.freeze_quantizer_level - self.quantizer.n_q_semantic
                ):
                    self.quantizer.rvq_rest.vq.layers[i].eval()
            elif isinstance(self.quantizer, ResidualVectorQuantizer):
                for i in range(self.freeze_quantizer_level):
                    self.quantizer.vq.layers[i].eval()
            else:
                raise ValueError(f"Unsupported quantizer type {type(self.quantizer)}")

        with self._context_for_encoder_decoder:
            emb = self.encoder(x)
        if self.encoder_transformer is not None:
            (emb,) = self.encoder_transformer(emb)
        emb = self._to_framerate(emb)
        expected_length = self.frame_rate * length / self.sample_rate
        # Checking that we have the proper length given the advertised frame rate.
        assert abs(emb.shape[-1] - expected_length) < 1, (
            emb.shape[-1],
            expected_length,
        )

        q_res = self.quantizer(emb, self.frame_rate)
        emb = q_res.x
        emb = self._to_encoder_framerate(emb)
        if self.decoder_transformer is not None:
            (emb,) = self.decoder_transformer(emb)

        with self._context_for_encoder_decoder:
            out = self.decoder(emb)

        # remove extra padding added by the encoder and decoder
        assert out.shape[-1] >= length, (out.shape[-1], length)
        out = out[..., :length]

        q_res.x = out
        q_res.metrics.update(extra_metrics)
        return q_res

    def _encode_to_unquantized_latent(self, x: torch.Tensor) -> torch.Tensor:
        """Projects a batch of waveforms to unquantized latent space.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T].

        Returns:
            Unquantized embeddings.
        """
        assert (
            x.dim() == 3
        ), f"CompressionModel._encode_to_unquantized_latent expects audio of shape [B, C, T] but got {x.shape}"
        state = self._streaming_state
        with self._context_for_encoder_decoder:
            emb = self.encoder(x)
        if self.encoder_transformer is not None:
            if state is None:
                (emb,) = self.encoder_transformer(emb)
            else:
                assert state.graphed_tr_enc is not None
                (emb,) = state.graphed_tr_enc(emb)
        emb = self._to_framerate(emb)
        return emb

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode the given input tensor to quantized representation.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T]

        Returns:
            codes (torch.Tensor): an int tensor of shape [B, K, T]
                with K the number of codebooks used and T the timestep.
        """
        emb = self._encode_to_unquantized_latent(x)
        codes = self.quantizer.encode(emb)
        return codes

    def encode_to_latent(self, x: torch.Tensor, quantize: bool = True) -> torch.Tensor:
        """Projects a batch of waveforms to latent space.

        Args:
            x (torch.Tensor): Float tensor of shape [B, C, T].

        Returns:
            Embeddings, either quantized or not.
        """
        emb = self._encode_to_unquantized_latent(x)
        if not quantize:
            return emb
        else:
            codes = self.quantizer.encode(emb)
            return self.decode_latent(codes)

    def decode(self, codes: torch.Tensor):
        """Decode the given codes to a reconstructed representation.

        Args:
            codes (torch.Tensor): Int tensor of shape [B, K, T]

        Returns:
            out (torch.Tensor): Float tensor of shape [B, C, T], the reconstructed audio.
        """
        state = self._streaming_state
        emb = self.decode_latent(codes)
        emb = self._to_encoder_framerate(emb)
        if self.decoder_transformer is not None:
            if state is None:
                (emb,) = self.decoder_transformer(emb)
            else:
                assert state.graphed_tr_dec is not None
                (emb,) = state.graphed_tr_dec(emb)
        with self._context_for_encoder_decoder:
            out = self.decoder(emb)
        # out contains extra padding added by the encoder and decoder
        return out

    def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode from the discrete codes to continuous latent space."""
        return self.quantizer.decode(codes)


class WrapperCompressionModel(CompressionModel[State]):
    """Base API for CompressionModel wrappers that do not depend on external frameworks."""

    def __init__(self, model: CompressionModel):
        super().__init__()
        self.model = model

    def forward(self, x: torch.Tensor) -> QuantizedResult:
        return self.model.forward(x)

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        return self.model.encode(x)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        return self.model.decode(codes)

    def decode_latent(self, codes: torch.Tensor) -> torch.Tensor:
        return self.model.decode_latent(codes)

    def set_num_codebooks(self, n: int):
        self.model.set_num_codebooks(n)

    @property
    def quantizer(self):
        return self.model.quantizer

    @property
    def channels(self) -> int:
        return self.model.channels

    @property
    def frame_rate(self) -> float:
        return self.model.frame_rate

    @property
    def sample_rate(self) -> int:
        return self.model.sample_rate

    @property
    def cardinality(self) -> int:
        return self.model.cardinality

    @property
    def num_codebooks(self) -> int:
        return self.model.num_codebooks

    @property
    def total_codebooks(self) -> int:
        return self.model.total_codebooks
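To make the encode/decode contract concrete, here is a minimal round-trip sketch. Assumptions: the Mimi checkpoint shipped in models/ above, CPU execution, and the 24 kHz / 12.5 Hz defaults; the two-second input length is illustrative.

```python
import torch
from moshi.models import loaders

# Load the Mimi audio tokenizer, as inference.py does.
mimi = loaders.get_mimi(
    "models/tokenizer-e351c8d8-checkpoint125.safetensors", device="cpu")
mimi.set_num_codebooks(8)

frame_size = int(mimi.sample_rate / mimi.frame_rate)  # 24000 / 12.5 = 1920
wav = torch.zeros(1, 1, frame_size * 25)   # [B, C, T]: 2 s of silence
codes = mimi.encode(wav)                   # int codes, [1, 8, 25]
recon = mimi.decode(codes)                 # [1, 1, ~48000]; may carry extra padding
print(codes.shape, recon.shape)
```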
moshi/models/lm.py
ADDED
@@ -0,0 +1,487 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
from functools import partial
import logging
import typing as tp

import torch
from torch import nn

from moshi.utils.sampling import sample_token
from moshi.utils.compile import CUDAGraphed
from moshi.modules.streaming import StreamingContainer, StreamingModule
from moshi.modules.transformer import (
    StreamingTransformer,
    create_norm_fn,
)


logger = logging.getLogger(__name__)


class ScaledEmbedding(nn.Embedding):
    """Boost learning rate for embeddings (with `scale`).

    Args:
        norm (bool): if True, uses a layer norm after the embedding.
        zero_idx (int): special value indicating that the output should be exactly 0.
    """

    def __init__(self, *args, norm: bool = False, zero_idx: int = -1, **kwargs):
        super().__init__(*args, **kwargs)
        self.norm = None
        if norm:
            self.norm = create_norm_fn("layer_norm", self.embedding_dim)
        assert zero_idx < 0, "Please use negative values for the zero_idx."
        self.zero_idx = zero_idx

    def forward(self, input, *args, **kwargs):
        is_zero = input == self.zero_idx
        zero = torch.zeros(1, dtype=input.dtype, device=input.device)
        input = input.clamp(min=0)
        y = super().forward(input, *args, **kwargs)
        if self.norm is not None:
            y = self.norm(y)
        y = torch.where(is_zero[..., None], zero, y)
        return y


class LMModel(StreamingContainer):
    """Transformer-based language model on multiple streams of codes.

    Args:
        n_q (int): Number of parallel streams to model as input.
        dep_q (int): Number of parallel streams to model in the depformer.
        card (int): Cardinality, vocabulary size.
        text_card (int): Cardinality of the text vocabulary.
        dim (int): Dimension of the transformer encoder.
        num_heads (int): Number of heads for the transformer encoder.
        hidden_scale (int): Scale for hidden feed forward dimension of the transformer encoder.
        norm (str): Normalization method.
        norm_emb (bool): Whether to normalize embeddings.
        bias_proj (bool): Use bias for output projections.
        depformer_*: params used for the Depformer Transformer, all the others will be shared.
        depformer_multi_linear (bool): if True, uses one linear layer per codebook to project the
            output of the main transformer to the Depformer latent space.
        depformer_dim_feedforward (int | list[int] | None): If None, defaults to hidden_scale * depformer_dim.
        existing_text_padding_id (bool): if True, will use a different token for the initial text token, and
            the text padding token.
        same_initial (bool): if True, uses the same initial tokens for both text and audio mode.
        **kwargs: Additional parameters for the transformer encoder.
    """

    def __init__(
        self,
        delays: tp.List[int] = [0],
        n_q: int = 8,
        dep_q: int = 8,
        card: int = 1024,
        text_card: int = 32000,
        dim: int = 128,
        num_heads: int = 8,
        hidden_scale: int = 4,
        norm: str = "layer_norm",
        norm_emb: bool = False,
        bias_proj: bool = False,
        depformer_dim: int = 256,
        depformer_dim_feedforward: int | list[int] | None = None,
        depformer_multi_linear: bool = False,
        depformer_weights_per_step: bool = False,
        depformer_pos_emb: str = "sin",
        existing_text_padding_id: tp.Optional[int] = None,
        context: tp.Optional[int] = None,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__()
        self.n_q = n_q
        self.dep_q = dep_q
        self.card = card
        self.text_card = text_card
        assert len(delays) == self.num_codebooks, "unexpected number of delays"
        self.delays = delays
        self.dim = dim
        self.existing_text_padding_id = existing_text_padding_id
        self.context = context
        kwargs["context"] = context
        EmbeddingFactory = partial(
            ScaledEmbedding,
            norm=norm_emb,
            device=device,
            dtype=dtype,
            zero_idx=self.zero_token_id,
        )
        self.emb = nn.ModuleList(
            [EmbeddingFactory(self.card + 1, dim) for _ in range(n_q)]
        )
        # Text card + padding token (if not in the original tokenizer)
        extra_text = self.existing_text_padding_id is None
        # Unlike for audio, here we authorize the model to output the special token.
        self.text_emb = EmbeddingFactory(text_card + 1, dim)
        self.text_linear = nn.Linear(dim, text_card + extra_text, bias=bias_proj)
        depformer_prefix = "depformer_"
        main_kwargs = {
            k: v for k, v in kwargs.items() if not k.startswith(depformer_prefix)
        }
        self.transformer = StreamingTransformer(
            d_model=dim,
            num_heads=num_heads,
            dim_feedforward=int(hidden_scale * dim),
            norm=norm,
            device=device,
            dtype=dtype,
            **main_kwargs,
        )
        self.out_norm = create_norm_fn(norm, dim)
        self.depformer_multi_linear = depformer_multi_linear
        kwargs_dep = main_kwargs.copy()
        kwargs_dep.update(
            {
                k.removeprefix(depformer_prefix): v
                for k, v in kwargs.items()
                if k.startswith(depformer_prefix)
            }
        )
        kwargs_dep["positional_embedding"] = depformer_pos_emb
        kwargs_dep["context"] = None
        if depformer_weights_per_step:
            kwargs_dep["weights_per_step"] = dep_q
        if depformer_multi_linear:
            # One linear layer per codebook to project different information from the main model.
            self.depformer_in = nn.ModuleList(
                [nn.Linear(dim, depformer_dim, bias=False) for _ in range(dep_q)]
            )
        else:
            self.depformer_in = nn.ModuleList(
                [nn.Linear(dim, depformer_dim, bias=False)]
            )
        # Only using up to dep_q - 1 because the last codebook is never an input to Depformer.
        self.depformer_emb = nn.ModuleList(
            [EmbeddingFactory(self.card + 1, depformer_dim) for _ in range(dep_q - 1)]
        )
        self.depformer_text_emb = EmbeddingFactory(text_card + 1, depformer_dim)
        if depformer_dim_feedforward is None:
            depformer_dim_feedforward = int(hidden_scale * depformer_dim)
        self.depformer = StreamingTransformer(
            d_model=depformer_dim,
            dim_feedforward=depformer_dim_feedforward,
            norm=norm,
            device=device,
            dtype=dtype,
            **kwargs_dep,
        )
        self.depformer.set_streaming_propagate(False)
        dim = depformer_dim  # we will directly apply the next linears to the output of the Depformer.

        self.linears = nn.ModuleList(
            [nn.Linear(dim, self.card, bias=bias_proj) for _ in range(dep_q)]
        )

    @property
    def initial_token_id(self) -> int:
        """Token id for the start of sequence (audio)."""
        return self.card

    @property
    def text_initial_token_id(self) -> int:
        """Token id for the start of sequence (text)."""
        return self.text_card

    @property
    def text_padding_token_id(self) -> int:
        """Token id for text padding."""
        if self.existing_text_padding_id is None:
            return self.text_card
        else:
            return self.existing_text_padding_id

    @property
    def end_of_text_padding_id(self) -> int:
        """Token id for optionally marking the last padding step for a word."""
        return 0

    @property
    def zero_token_id(self) -> int:
        """Special value in the input tokens, indicating that no sampling should
        happen for that value, and no input should be given to the model."""
        return -1

    @property
    def ungenerated_token_id(self) -> int:
        """Special value that can be provided in the prompt to indicate that this specific
        value should be predicted and sampled. This allows for partial teacher forcing, by generating
        one modality, with the other one fixed.
        """
        return -2

    @property
    def device(self):
        first_param = next(iter(self.parameters()))
        return first_param.device

    @property
    def num_codebooks(self) -> int:
        return self.n_q + 1

    @property
    def num_audio_codebooks(self) -> int:
        return self.n_q

    @property
    def audio_offset(self) -> int:
        return 1

    def _get_initial_token(self) -> torch.Tensor:
        # Returns the initial token that will be fed to the model to predict the very first timestep.
        # The output shape will be [B, K, 1].
        device = next(iter(self.parameters())).device
        zero = torch.full(
            [1, 1, 1], self.zero_token_id, device=device, dtype=torch.long
        )
        special = torch.full_like(zero, self.initial_token_id)

        text_special = torch.full_like(zero, self.text_initial_token_id)
        audio_token = special
        text_token = text_special
        audio_token = audio_token.expand(-1, self.num_audio_codebooks, -1)
        token = torch.cat([text_token, audio_token], dim=1)
        return token

    def forward_text(
        self,
        sequence: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        B, K, S = sequence.shape
        assert (
            K == self.num_codebooks
        ), f"Sequence shape {sequence.shape} must match the number of codebooks."
        input_sequence = sequence
        input_ = None
        for cb_index in range(self.num_audio_codebooks):
            audio_emb = self.emb[cb_index](
                input_sequence[:, cb_index + self.audio_offset]
            )
            input_ = audio_emb if input_ is None else input_ + audio_emb
        text_emb = self.text_emb(input_sequence[:, 0])
        input_ = text_emb if input_ is None else input_ + text_emb
        transformer_out = self.transformer(input_)

        if self.out_norm:
            transformer_out = self.out_norm(transformer_out)
        assert isinstance(transformer_out, torch.Tensor)
        text_logits = self.text_linear(transformer_out)
        text_logits = text_logits[:, None]
        return transformer_out, text_logits

    def forward_depformer(
        self,
        depformer_cb_index: int,
        sequence: torch.Tensor,
        transformer_out: torch.Tensor,
    ) -> torch.Tensor:
        B, K, S = sequence.shape
        assert (
            K == 1
        ), f"Codebooks for Depformer streaming should be passed 1 by 1, got {K}."
        assert (
            S == 1
        ), f"Steps for Depformer streaming should be passed 1 by 1, got {S}."
        assert (
            transformer_out.shape[1] == 1
        ), "Transformer out should be for a single step."
        last_token_input: tp.Optional[torch.Tensor] = None
        depformer_input = transformer_out
        if self.depformer_multi_linear:
            depformer_input = self.depformer_in[depformer_cb_index](depformer_input)
        else:
            depformer_input = self.depformer_in[0](depformer_input)
        if depformer_cb_index == 0:
            last_token_input = self.depformer_text_emb(sequence[:, 0])
        else:
            last_token_input = self.depformer_emb[depformer_cb_index - 1](
                sequence[:, 0]
            )
        depformer_input = depformer_input + last_token_input
        assert depformer_input.shape[1] == 1
        # depformer_input is [B, 1, depformer_dim].
        # The streaming state of the depformer ensures that the proper layer is run.
        dep_output = self.depformer(depformer_input)
        logits = self.linears[depformer_cb_index](dep_output)
        logits = logits[:, None]
        assert logits.dim() == 4, logits.shape  # [B, Ka, S, card]
        return logits


@dataclass
class _LMGenState:
    cache: torch.Tensor
    initial: torch.Tensor
    graphed_main: CUDAGraphed
    graphed_depth: CUDAGraphed
    offset: int = 0

    def reset(self):
        self.offset = 0


class LMGen(StreamingModule[_LMGenState]):
    def __init__(
        self,
        lm_model: LMModel,
        use_sampling: bool = True,
        temp: float = 0.8,
        temp_text: float = 0.7,
        top_k: int = 250,
        top_k_text: int = 25,
        check: bool = False,
    ):
        assert not lm_model.training, "generation shouldn't be used in training mode."
        super().__init__()

        self.lm_model = lm_model
        self.use_sampling = use_sampling
        self.temp = temp
        self.temp_text = temp_text
        self.top_k = top_k
        self.top_k_text = top_k_text
        self.check = check
        self.max_delay = max(
            lm_model.delays
        )  # with delays, we need to generate a few more time steps.
        self.delays_cuda = torch.tensor(
            lm_model.delays, device=lm_model.device, dtype=torch.long
        )

    def _init_streaming_state(self, batch_size: int) -> _LMGenState:
        lm_model = self.lm_model
        initial = lm_model._get_initial_token()
        cache = torch.full(
            (batch_size, self.lm_model.num_codebooks, self.max_delay + 2),
            lm_model.ungenerated_token_id,
            device=lm_model.device,
            dtype=torch.long,
        )

        disable = lm_model.device.type != 'cuda'
        graphed_main = CUDAGraphed(lm_model.forward_text, disable=disable)
        graphed_depth = CUDAGraphed(self.depformer_step, disable=disable)

        return _LMGenState(cache, initial, graphed_main, graphed_depth)

    @torch.no_grad()
    def step(self, input_tokens: torch.Tensor) -> torch.Tensor | None:
        state = self._streaming_state
        if state is None:
            raise RuntimeError(
                "You should wrap those calls with a `with lm_gen.streaming(): ...`."
            )
|
387 |
+
)
|
388 |
+
lm_model = self.lm_model
|
389 |
+
|
390 |
+
assert input_tokens.dim() == 3, "Shape should be [B, K, T]."
|
391 |
+
B, Ki, S = input_tokens.shape
|
392 |
+
assert S == 1, "Only support being given steps one by one."
|
393 |
+
needed_tokens = lm_model.num_codebooks - lm_model.dep_q - 1
|
394 |
+
assert (
|
395 |
+
Ki == needed_tokens
|
396 |
+
), f"We expect {needed_tokens} tokens from the user stream, got {Ki}."
|
397 |
+
|
398 |
+
CT = state.cache.shape[2]
|
399 |
+
for q_other in range(input_tokens.shape[1]):
|
400 |
+
k = lm_model.dep_q + 1 + q_other
|
401 |
+
delay = lm_model.delays[k]
|
402 |
+
write_position = (state.offset + delay) % CT
|
403 |
+
state.cache[:, k, write_position : write_position + 1] = input_tokens[
|
404 |
+
:, q_other
|
405 |
+
]
|
406 |
+
|
407 |
+
position = state.offset % CT
|
408 |
+
for k, delay in enumerate(lm_model.delays):
|
409 |
+
# Only for the very beginning, we extend the initial token for the acoustic
|
410 |
+
# token that are delayed, and thus have no good value to take.
|
411 |
+
if state.offset <= delay:
|
412 |
+
state.cache[:, k, position] = state.initial[:, k, 0]
|
413 |
+
input_ = state.cache[:, :, position : position + 1]
|
414 |
+
|
415 |
+
if self.check:
|
416 |
+
# Check that we are not feeding in any value that is not generated yet.
|
417 |
+
assert not (input_ == lm_model.ungenerated_token_id).any(), (
|
418 |
+
state.offset,
|
419 |
+
input_,
|
420 |
+
)
|
421 |
+
assert (input_[:, lm_model.audio_offset :] <= lm_model.card).all(), input_
|
422 |
+
assert (input_[:, :1] <= lm_model.text_card).all()
|
423 |
+
|
424 |
+
transformer_out, text_logits = state.graphed_main(input_)
|
425 |
+
# Shape of text_logits should be [B, K_text=1, T=1, Card_text]
|
426 |
+
text_token = sample_token(
|
427 |
+
text_logits.float(),
|
428 |
+
self.use_sampling,
|
429 |
+
self.temp_text,
|
430 |
+
self.top_k_text,
|
431 |
+
)
|
432 |
+
assert text_token.dim() == 3, text_token.shape
|
433 |
+
assert text_token.shape[2] == 1
|
434 |
+
assert text_token.shape[1] == 1, "Only one text stream supported."
|
435 |
+
text_token = text_token[:, 0, 0] # shape is [B]
|
436 |
+
audio_tokens = state.graphed_depth(text_token, transformer_out)
|
437 |
+
|
438 |
+
# ensure we don't overwrite prompt tokens, we only write over ungenerated tokens
|
439 |
+
state.offset += 1
|
440 |
+
position = state.offset % CT
|
441 |
+
state.cache[:, 0, position] = text_token
|
442 |
+
state.cache[:, 1 : lm_model.dep_q + 1, position] = audio_tokens
|
443 |
+
|
444 |
+
if state.offset <= self.max_delay:
|
445 |
+
return None
|
446 |
+
B = state.cache.shape[0]
|
447 |
+
gen_delays_cuda = self.delays_cuda[: lm_model.dep_q + 1]
|
448 |
+
index = (
|
449 |
+
((state.offset - self.max_delay + gen_delays_cuda) % CT)
|
450 |
+
.view(1, -1, 1)
|
451 |
+
.expand(B, -1, 1)
|
452 |
+
)
|
453 |
+
out = state.cache.gather(dim=2, index=index)
|
454 |
+
return out
|
455 |
+
|
456 |
+
def depformer_step(
|
457 |
+
self,
|
458 |
+
text_token: torch.Tensor,
|
459 |
+
transformer_out: torch.Tensor,
|
460 |
+
) -> torch.Tensor:
|
461 |
+
(B,) = text_token.shape
|
462 |
+
prev_token = text_token
|
463 |
+
lm_model = self.lm_model
|
464 |
+
depformer_tokens: list[torch.Tensor] = []
|
465 |
+
assert not lm_model.depformer.is_streaming
|
466 |
+
with lm_model.depformer.streaming(B):
|
467 |
+
for cb_index in range(lm_model.dep_q):
|
468 |
+
input_ = prev_token[:, None, None]
|
469 |
+
logits = lm_model.forward_depformer(cb_index, input_, transformer_out)
|
470 |
+
next_token = sample_token(
|
471 |
+
logits.float(),
|
472 |
+
self.use_sampling,
|
473 |
+
self.temp,
|
474 |
+
self.top_k,
|
475 |
+
)
|
476 |
+
assert next_token.shape == (B, 1, 1)
|
477 |
+
next_token = next_token[:, 0, 0] # shape is B
|
478 |
+
depformer_tokens.append(next_token)
|
479 |
+
prev_token = next_token
|
480 |
+
|
481 |
+
assert len(depformer_tokens) == lm_model.dep_q, (
|
482 |
+
len(depformer_tokens),
|
483 |
+
lm_model.dep_q,
|
484 |
+
)
|
485 |
+
out = torch.stack(depformer_tokens, dim=1)
|
486 |
+
assert out.shape == (B, lm_model.dep_q), out.shape
|
487 |
+
return out
|
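Note: `LMGen.step` consumes exactly one timestep of user codes per call and returns `None` until the largest codebook delay has been primed, after which each call yields one delay-aligned frame. A minimal driver sketch, not part of the source: `lm` is assumed to be a pretrained, eval-mode `LMModel` built elsewhere, and `user_frames` a hypothetical `[B, K, T]` tensor of Mimi codes for the user stream.

import torch

lm_gen = LMGen(lm)  # assumption: `lm` is an LMModel loaded elsewhere
B, K, T = user_frames.shape
assert K == lm.num_codebooks - lm.dep_q - 1  # codebooks expected from the user stream
frames = []
with lm_gen.streaming(B):
    for t in range(T):
        out = lm_gen.step(user_frames[:, :, t : t + 1])  # one [B, K, 1] step at a time
        if out is not None:  # None while the delay cache is still being primed
            frames.append(out)  # [B, dep_q + 1, 1]: text token plus generated audio codes

The circular cache (`% CT`) bounds the streaming state to `max_delay + 2` timesteps, so this loop runs in constant memory no matter how long generation continues.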
moshi/models/loaders.py
ADDED
@@ -0,0 +1,159 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Retrieves the pretrained models for Moshi and Mimi."""
from pathlib import Path

from safetensors.torch import load_model
import torch

from moshi.models.compression import MimiModel
from moshi.models.lm import LMModel
from moshi.modules import SEANetEncoder, SEANetDecoder, transformer
from moshi.quantization import SplitResidualVectorQuantizer

SAMPLE_RATE = 24000
FRAME_RATE = 12.5

TEXT_TOKENIZER_NAME = 'tokenizer_spm_32k_3.model'
MOSHI_NAME = 'model.safetensors'
MIMI_NAME = 'tokenizer-e351c8d8-checkpoint125.safetensors'
DEFAULT_REPO = 'kyutai/moshiko-pytorch-bf16'


_seanet_kwargs = {
    "channels": 1,
    "dimension": 512,
    "causal": True,
    "n_filters": 64,
    "n_residual_layers": 1,
    "activation": "ELU",
    "compress": 2,
    "dilation_base": 2,
    "disable_norm_outer_blocks": 0,
    "kernel_size": 7,
    "residual_kernel_size": 3,
    "last_kernel_size": 3,
    # We train using weight_norm but then the weights are pre-processed for inference so
    # that we can use a normal convolution.
    "norm": "none",
    "pad_mode": "constant",
    "ratios": [8, 6, 5, 4],
    "true_skip": True,
}
_quantizer_kwargs = {
    "dimension": 256,
    "n_q": 32,
    "bins": 2048,
    "input_dimension": _seanet_kwargs["dimension"],
    "output_dimension": _seanet_kwargs["dimension"],
}
_transformer_kwargs = {
    "d_model": _seanet_kwargs["dimension"],
    "num_heads": 8,
    "num_layers": 8,
    "causal": True,
    "layer_scale": 0.01,
    "context": 250,
    "conv_layout": True,
    "max_period": 10000,
    "gating": "none",
    "norm": "layer_norm",
    "positional_embedding": "rope",
    "dim_feedforward": 2048,
    "input_dimension": _seanet_kwargs["dimension"],
    "output_dimensions": [_seanet_kwargs["dimension"]],
}

_lm_kwargs = {
    "dim": 4096,
    "text_card": 32000,
    "existing_text_padding_id": 3,
    "n_q": 16,
    "dep_q": 8,
    "card": _quantizer_kwargs["bins"],
    "num_heads": 32,
    "num_layers": 32,
    "hidden_scale": 4.125,
    "causal": True,
    "layer_scale": None,
    "context": 3000,
    "max_period": 10000,
    "gating": "silu",
    "norm": "rms_norm_f32",
    "positional_embedding": "rope",
    "depformer_dim": 1024,
    "depformer_dim_feedforward": int(4.125 * 1024),
    "depformer_num_heads": 16,
    "depformer_num_layers": 6,
    "depformer_causal": True,
    "depformer_layer_scale": None,
    "depformer_multi_linear": True,
    "depformer_context": 8,
    "depformer_max_period": 10000,
    "depformer_gating": "silu",
    "depformer_pos_emb": "none",
    "depformer_weights_per_step": True,
    "delays": [0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1],
}


def _is_safetensors(path: Path | str) -> bool:
    return Path(path).suffix in (".safetensors", ".sft", ".sfts")


def get_mimi(filename: str | Path,
             device: torch.device | str = 'cpu') -> MimiModel:
    """Return a pretrained Mimi model."""
    encoder = SEANetEncoder(**_seanet_kwargs)
    decoder = SEANetDecoder(**_seanet_kwargs)
    encoder_transformer = transformer.ProjectedTransformer(
        device=device, **_transformer_kwargs
    )
    decoder_transformer = transformer.ProjectedTransformer(
        device=device, **_transformer_kwargs
    )
    quantizer = SplitResidualVectorQuantizer(
        **_quantizer_kwargs,
    )
    model = MimiModel(
        encoder,
        decoder,
        quantizer,
        channels=1,
        sample_rate=SAMPLE_RATE,
        frame_rate=FRAME_RATE,
        encoder_frame_rate=SAMPLE_RATE / encoder.hop_length,
        causal=True,
        resample_method="conv",
        encoder_transformer=encoder_transformer,
        decoder_transformer=decoder_transformer,
    ).to(device=device)
    model.eval()
    if _is_safetensors(filename):
        load_model(model, filename)
    else:
        pkg = torch.load(filename, "cpu")
        model.load_state_dict(pkg["model"])
    model.set_num_codebooks(8)
    return model


def get_moshi_lm(filename: str | Path,
                 device: torch.device | str = 'cpu') -> LMModel:
    dtype = torch.bfloat16
    model = LMModel(
        device=device,
        dtype=dtype,
        **_lm_kwargs,
    ).to(device=device, dtype=dtype)
    model.eval()
    if _is_safetensors(filename):
        load_model(model, filename)
    else:
        pkg = torch.load(
            filename,
            "cpu",
        )
        model.load_state_dict(pkg["fsdp_best_state"]["model"])
    return model
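Note: the constants above name checkpoint files hosted in the `DEFAULT_REPO` Hugging Face repository. A sketch of how a caller might fetch and load them, assuming the `huggingface_hub` package is available; the download step itself is not part of this file:

from huggingface_hub import hf_hub_download

from moshi.models.loaders import (
    DEFAULT_REPO, MIMI_NAME, MOSHI_NAME, get_mimi, get_moshi_lm)

mimi_path = hf_hub_download(DEFAULT_REPO, MIMI_NAME)
moshi_path = hf_hub_download(DEFAULT_REPO, MOSHI_NAME)
mimi = get_mimi(mimi_path, device='cuda')        # 24 kHz codec at 12.5 Hz frames, 8 codebooks
moshi = get_moshi_lm(moshi_path, device='cuda')  # bfloat16 LM built with the kwargs above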
moshi/modules/__init__.py
ADDED
@@ -0,0 +1,23 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Modules used for building the models."""

# flake8: noqa
from .conv import (
    NormConv1d,
    NormConvTranspose1d,
    StreamingConv1d,
    StreamingConvTranspose1d,
    pad_for_conv1d,
    pad1d,
    unpad1d,
)
from .seanet import SEANetEncoder, SEANetDecoder
from .transformer import StreamingTransformer
moshi/modules/conv.py
ADDED
@@ -0,0 +1,329 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass
import math
import typing as tp
import warnings

import torch
from torch import nn
from torch.nn import functional as F
from torch.nn.utils import weight_norm

from .streaming import RawStreamingConv1d, RawStreamingConvTranspose1d, StreamingModule


CONV_NORMALIZATIONS = frozenset(["none", "weight_norm"])


class TransposedLayerNorm(nn.Module):
    """LayerNorm for [B, C, T] inputs."""

    def __init__(self, **kwargs):
        super().__init__()
        self.layer_norm = nn.LayerNorm(**kwargs)

    def forward(self, x):
        x = x.transpose(1, 2)
        x = self.layer_norm(x)
        return x.transpose(1, 2)


def apply_parametrization_norm(module: nn.Module, norm: str = "none"):
    assert norm in CONV_NORMALIZATIONS
    if norm == "weight_norm":
        return weight_norm(module)
    else:
        # We already checked that `norm` is in CONV_NORMALIZATIONS, so any other
        # choice doesn't need reparametrization.
        return module


def get_extra_padding_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
) -> int:
    """See `pad_for_conv1d`."""
    length = x.shape[-1]
    n_frames = (length - kernel_size + padding_total) / stride + 1
    ideal_length = (math.ceil(n_frames) - 1) * stride + (kernel_size - padding_total)
    return ideal_length - length


def pad_for_conv1d(
    x: torch.Tensor, kernel_size: int, stride: int, padding_total: int = 0
):
    """Pad for a convolution to make sure that the last window is full.
    Extra padding is added at the end. This is required to ensure that we can rebuild
    an output of the same length, as otherwise, even with padding, some time steps
    might get removed.
    For instance, with total padding = 4, kernel size = 4, stride = 2:
        0 0 1 2 3 4 5 0 0   # (0s are padding)
        1   2   3           # (output frames of a convolution, last 0 is never used)
        0 0 1 2 3 4 5 0     # (output of tr. conv., but pos. 5 is going to get removed as padding)
        1 2 3 4             # once you remove padding, we are missing one time step!
    """
    extra_padding = get_extra_padding_for_conv1d(x, kernel_size, stride, padding_total)
    return F.pad(x, (0, extra_padding))


def pad1d(
    x: torch.Tensor,
    paddings: tp.Tuple[int, int],
    mode: str = "constant",
    value: float = 0.0,
):
    """Tiny wrapper around F.pad, just to allow for reflect padding on small input.
    If this is the case, we insert extra 0 padding to the right before the reflection happens.
    """
    length = x.shape[-1]
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    if mode == "reflect":
        max_pad = max(padding_left, padding_right)
        extra_pad = 0
        if length <= max_pad:
            extra_pad = max_pad - length + 1
            x = F.pad(x, (0, extra_pad))
        padded = F.pad(x, paddings, mode, value)
        end = padded.shape[-1] - extra_pad
        return padded[..., :end]
    else:
        return F.pad(x, paddings, mode, value)


def unpad1d(x: torch.Tensor, paddings: tp.Tuple[int, int]):
    """Remove padding from x, handling properly zero padding. Only for 1d!"""
    padding_left, padding_right = paddings
    assert padding_left >= 0 and padding_right >= 0, (padding_left, padding_right)
    assert (padding_left + padding_right) <= x.shape[-1]
    end = x.shape[-1] - padding_right
    return x[..., padding_left:end]


class NormConv1d(nn.Module):
    """Wrapper around Conv1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.conv = apply_parametrization_norm(
            RawStreamingConv1d(*args, **kwargs), norm
        )
        self.norm_type = norm

    def forward(self, x):
        x = self.conv(x)
        return x


class NormConvTranspose1d(nn.Module):
    """Wrapper around ConvTranspose1d and normalization applied to this conv
    to provide a uniform interface across normalization approaches.
    """

    def __init__(
        self,
        *args,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        **kwargs,
    ):
        super().__init__()
        self.convtr = apply_parametrization_norm(
            RawStreamingConvTranspose1d(*args, **kwargs), norm
        )
        self.norm_type = norm

    def forward(self, x):
        x = self.convtr(x)
        return x


@dataclass
class _StreamingConv1dState:
    padding_to_add: int
    original_padding_to_add: int

    def reset(self):
        self.padding_to_add = self.original_padding_to_add


class StreamingConv1d(StreamingModule[_StreamingConv1dState]):
    """Conv1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        dilation: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        norm_kwargs: tp.Dict[str, tp.Any] = {},
        pad_mode: str = "reflect",
    ):
        super().__init__()
        # warn user on unusual setup between dilation and stride
        if stride > 1 and dilation > 1:
            warnings.warn(
                "StreamingConv1d has been initialized with stride > 1 and dilation > 1"
                f" (kernel_size={kernel_size} stride={stride}, dilation={dilation})."
            )
        self.conv = NormConv1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            dilation=dilation,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.pad_mode = pad_mode

    @property
    def _stride(self) -> int:
        return self.conv.conv.stride[0]

    @property
    def _kernel_size(self) -> int:
        return self.conv.conv.kernel_size[0]

    @property
    def _effective_kernel_size(self) -> int:
        dilation = self.conv.conv.dilation[0]
        return (
            self._kernel_size - 1
        ) * dilation + 1  # effective kernel size with dilations

    @property
    def _padding_total(self) -> int:
        return self._effective_kernel_size - self._stride

    def _init_streaming_state(self, batch_size: int) -> _StreamingConv1dState:
        assert self.causal, "streaming is only supported for causal convs"
        return _StreamingConv1dState(self._padding_total, self._padding_total)

    def forward(self, x):
        B, C, T = x.shape
        padding_total = self._padding_total
        extra_padding = get_extra_padding_for_conv1d(
            x, self._effective_kernel_size, self._stride, padding_total
        )
        state = self._streaming_state
        if state is None:
            if self.causal:
                # Left padding for causal
                x = pad1d(x, (padding_total, extra_padding), mode=self.pad_mode)
            else:
                # Asymmetric padding required for odd strides
                padding_right = padding_total // 2
                padding_left = padding_total - padding_right
                x = pad1d(
                    x, (padding_left, padding_right + extra_padding), mode=self.pad_mode
                )
        else:
            if state.padding_to_add > 0 and x.shape[-1] > 0:
                x = pad1d(x, (state.padding_to_add, 0), mode=self.pad_mode)
                state.padding_to_add = 0
        return self.conv(x)


@dataclass
class _StreamingConvTr1dState:
    pass

    def reset(self):
        pass


class StreamingConvTranspose1d(StreamingModule[_StreamingConvTr1dState]):
    """ConvTranspose1d with some builtin handling of asymmetric or causal padding
    and normalization.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        kernel_size: int,
        stride: int = 1,
        groups: int = 1,
        bias: bool = True,
        causal: bool = False,
        norm: str = "none",
        trim_right_ratio: float = 1.0,
        norm_kwargs: tp.Dict[str, tp.Any] = {},
    ):
        super().__init__()
        self.convtr = NormConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size,
            stride,
            groups=groups,
            bias=bias,
            causal=causal,
            norm=norm,
            norm_kwargs=norm_kwargs,
        )
        self.causal = causal
        self.trim_right_ratio = trim_right_ratio
        assert (
            self.causal or self.trim_right_ratio == 1.0
        ), "`trim_right_ratio` != 1.0 only makes sense for causal convolutions"
        assert self.trim_right_ratio >= 0.0 and self.trim_right_ratio <= 1.0

    def _init_streaming_state(self, batch_size: int) -> _StreamingConvTr1dState:
        assert self.causal, "streaming is only supported for causal convtrs"
        return _StreamingConvTr1dState()

    def forward(self, x):
        kernel_size = self.convtr.convtr.kernel_size[0]
        stride = self.convtr.convtr.stride[0]
        padding_total = kernel_size - stride

        y = self.convtr(x)

        if not self.is_streaming:
            # We will only trim fixed padding. Extra padding from `pad_for_conv1d` would be
            # removed at the very end, when keeping only the right length for the output,
            # as removing it here would require also passing the length at the matching layer
            # in the encoder.
            if self.causal:
                # Trim the padding on the right according to the specified ratio
                # if trim_right_ratio = 1.0, trim everything from right
                padding_right = math.ceil(padding_total * self.trim_right_ratio)
                padding_left = padding_total - padding_right
                y = unpad1d(y, (padding_left, padding_right))
            else:
                # Asymmetric padding required for odd strides
                padding_right = padding_total // 2
                padding_left = padding_total - padding_right
                y = unpad1d(y, (padding_left, padding_right))
        return y
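Note: to make the `pad_for_conv1d` docstring concrete, here is the arithmetic of `get_extra_padding_for_conv1d` run on the exact example it describes (length 5, kernel 4, stride 2, total padding 4); illustrative only:

import torch
from moshi.modules.conv import get_extra_padding_for_conv1d, pad_for_conv1d

x = torch.randn(1, 1, 5)
# n_frames = (5 - 4 + 4) / 2 + 1 = 3.5, so ceil gives 4 full frames;
# ideal_length = (4 - 1) * 2 + (4 - 4) = 6, hence one extra sample of end padding.
print(get_extra_padding_for_conv1d(x, kernel_size=4, stride=2, padding_total=4))  # 1
print(pad_for_conv1d(x, 4, 2, 4).shape)  # torch.Size([1, 1, 6])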
moshi/modules/gating.py
ADDED
@@ -0,0 +1,82 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch
from torch import nn
from torch.nn import functional as F

from ..utils.compile import torch_compile_lazy


@torch_compile_lazy
def gating_forward_kernel(
    weight_in: torch.Tensor, weight_out: torch.Tensor, activation, x: torch.Tensor
):
    x = F.linear(x, weight_in)
    B, T, _ = x.shape
    x = x.view(B, T, 2, -1)
    x = activation(x[..., 0, :]) * x[..., 1, :]
    x = F.linear(x, weight_out)
    return x


class ActivationGating(nn.Module):
    """
    Gating FFN layer, using the given activation.
    Args:
        dim (int): dimension of the input and output of the transformer.
        activation (any callable Tensor to Tensor): activation function to use.
        **factory_kwargs: other kwargs passed to the linear layer, in particular device and dtype.
    """

    _fsdp_final = True

    def __init__(self, dim: int, dim_feedforward: int, activation, **factory_kwargs):
        super().__init__()
        # We should have 8 d^2 param, instead we will have
        # 2 * h * d + h * d = 3 h * d = 8 d^2
        # so h = 8 d / 3 but following Hervé's advice we use 21 / 8 as an approx.
        if dim_feedforward == 4 * dim:
            hidden = (21 * dim) // 8
        else:
            hidden = (2 * dim_feedforward) // 3
        self.linear_in = nn.Linear(dim, 2 * hidden, bias=False, **factory_kwargs)
        self.linear_out = nn.Linear(hidden, dim, bias=False, **factory_kwargs)
        self.activation = activation

    def forward(self, x: torch.Tensor):
        return gating_forward_kernel(
            self.linear_in.weight, self.linear_out.weight, self.activation, x
        )


def _get_activation(name: str):
    if name in ["sigmoid", "tanh", "relu"]:
        return getattr(torch, name)
    elif name in ["leaky_relu", "elu", "gelu", "silu", "mish", "softsign"]:
        return getattr(torch.nn.functional, name)
    elif name == "identity":
        return torch.nn.Identity()
    else:
        raise ValueError(f"Unknown activation {name}")


def _make_gating(
    name: str, dim: int, dim_feedforward: int, **factory_kwargs
) -> nn.Module:
    return ActivationGating(
        dim, dim_feedforward, _get_activation(name), **factory_kwargs
    )


def make_gating(
    name: str, dim: int, dim_feedforward: int, **factory_kwargs
) -> nn.Module:
    gating = _make_gating(name, dim, dim_feedforward, **factory_kwargs)
    max_params = 2 * dim * dim_feedforward
    params = sum(p.numel() for p in gating.parameters())
    assert (
        params <= max_params
    ), f"{name} gating has {params} params, max is {max_params}"
    return gating
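Note: the hidden-size comment in `ActivationGating` can be checked numerically. A plain FFN with `dim_feedforward = 4 * dim` costs `2 * dim * 4 * dim = 8 d^2` parameters; the gated variant costs `3 * hidden * dim`, and `hidden = (21 * dim) // 8` gives about `7.875 d^2`, just under the budget that `make_gating` asserts. A small check, with `dim = 512` chosen purely for illustration:

from moshi.modules.gating import make_gating

d = 512
gate = make_gating("silu", dim=d, dim_feedforward=4 * d)
params = sum(p.numel() for p in gate.parameters())
print(params)           # 2064384 == 3 * ((21 * 512) // 8) * 512
print(2 * d * (4 * d))  # 2097152, the 8 d^2 budget of an ungated FFN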
moshi/modules/resample.py
ADDED
@@ -0,0 +1,119 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

from einops import rearrange
import torch
from torch import nn

from .conv import StreamingConv1d, StreamingConvTranspose1d


class ConvDownsample1d(nn.Module):
    """
    Downsampling by some integer amount `stride` using convolutions
    with a kernel size of twice the stride.
    If `causal` is True, the output uses a causal convolution.
    """

    def __init__(
        self,
        stride: int,
        dimension: tp.Optional[int] = None,
        causal: bool = False,
        learnt: bool = False,
        channel_wise: bool = False,
    ):
        super().__init__()
        self.learnt = learnt
        self.channel_wise = channel_wise
        groups = 1
        if learnt:
            assert dimension is not None, "Dimension required for learnt convolutions."
            in_channels = dimension
            out_channels = dimension
            if channel_wise:
                groups = dimension
        else:
            in_channels = 1
            out_channels = 1

        self.conv = StreamingConv1d(
            in_channels,
            out_channels,
            kernel_size=2 * stride,
            stride=stride,
            causal=causal,
            groups=groups,
            bias=False,
            pad_mode="replicate",
        )
        if not learnt:
            actual_conv = self.conv.conv.conv
            actual_conv.weight.requires_grad_(False)
            actual_conv.weight.data.fill_(1.0 / (2 * stride))

    def forward(self, x: torch.Tensor):
        batch_size = len(x)
        if not self.learnt:
            x = rearrange(x, "b c t -> (b c) () t")
        y = self.conv(x)
        if not self.learnt:
            y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
        return y


class ConvTrUpsample1d(nn.Module):
    """
    Upsample by some integer amount `stride` using transposed convolutions.
    """

    def __init__(
        self,
        stride: int,
        dimension: tp.Optional[int] = None,
        causal: bool = False,
        learnt: bool = False,
        channel_wise: bool = False,
    ):
        super().__init__()
        self.learnt = learnt
        self.channel_wise = channel_wise
        groups = 1
        if learnt:
            assert dimension is not None, "Dimension required for learnt convolutions."
            in_channels = dimension
            out_channels = dimension
            if channel_wise:
                groups = dimension
        else:
            in_channels = 1
            out_channels = 1

        self.convtr = StreamingConvTranspose1d(
            in_channels,
            out_channels,
            kernel_size=2 * stride,
            stride=stride,
            causal=causal,
            groups=groups,
            bias=False,
        )
        if not learnt:
            actual_convtr = self.convtr.convtr.convtr
            actual_convtr.weight.requires_grad_(False)
            actual_convtr.weight.data.fill_(1.0)

    def forward(self, x: torch.Tensor):
        batch_size = len(x)
        if not self.learnt:
            x = rearrange(x, "b c t -> (b c) () t")
        y = self.convtr(x)
        if not self.learnt:
            x_for_normalization = torch.ones_like(x[:1])
            normalization = self.convtr(x_for_normalization)
            y = y / normalization
            y = rearrange(y, "(b c) () t -> b c t", b=batch_size)
        return y
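Note: with `learnt=False`, `ConvDownsample1d` freezes its kernel at `1 / (2 * stride)`, i.e. it acts as a fixed box filter averaging each window. A quick sketch of that behavior on a constant signal (illustrative; the exact values rely on the replicate padding of the causal path):

import torch
from moshi.modules.resample import ConvDownsample1d

down = ConvDownsample1d(stride=2, causal=True)  # non-learnt: fixed averaging kernel
x = torch.ones(1, 3, 8)  # any channel count works; channels are folded into the batch
y = down(x)
print(y.shape)  # torch.Size([1, 3, 4]): time divided by the stride
print(y)        # all ones: averaging a constant signal is the identity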
moshi/modules/rope.py
ADDED
@@ -0,0 +1,90 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

from torch import nn
import math
import torch
from ..utils.compile import torch_compile_lazy


@torch_compile_lazy
def apply_rope(
    q: torch.Tensor,
    k: torch.Tensor,
    offset: torch.Tensor,
    max_period: float = 10_000,
    time_before_heads: bool = False,
):
    """
    Args:
        q (torch.Tensor): queries, shape `[B, T, H, D]`.
        k (torch.Tensor): keys, shape `[B, T, H, D]`.
        offset (torch.Tensor): current offset, e.g. when streaming.
        max_period (float): maximum period for the cos and sin.
        time_before_heads (bool): if True, expected shape is [B, T, H, D], else [B, H, T, D].
    """

    if time_before_heads:
        B, T, H, D = q.shape
    else:
        B, H, T, D = q.shape
    assert k.shape == q.shape
    assert D > 0
    assert D % 2 == 0
    assert max_period > 0

    ds = torch.arange(D // 2, device=q.device, dtype=torch.float32)
    freqs = torch.exp(ds * (-math.log(max_period) * 2 / D))
    ts = offset.float() + torch.arange(T, device=q.device, dtype=torch.float32)
    if time_before_heads:
        ts = ts.view(-1, 1, 1)
    else:
        ts = ts.view(1, -1, 1)

    dims = q.shape[:-1]
    q = q.view(*dims, D // 2, 2)
    k = k.view(*dims, D // 2, 2)

    # convention is `r` suffix is real part, `i` is imaginary.
    qr = q[..., 0].float()
    qi = q[..., 1].float()

    kr = k[..., 0].float()
    ki = k[..., 1].float()

    rotr = torch.cos(freqs * ts)
    roti = torch.sin(freqs * ts)
    qor = qr * rotr - qi * roti
    qoi = qr * roti + qi * rotr

    kor = kr * rotr - ki * roti
    koi = kr * roti + ki * rotr

    dtype = q.dtype
    qo = torch.stack([qor.to(dtype), qoi.to(dtype)], dim=-1)
    ko = torch.stack([kor.to(dtype), koi.to(dtype)], dim=-1)

    return qo.view(*dims, D), ko.view(*dims, D)


class RotaryEmbedding(nn.Module):
    """Rotary positional embedding (RoPE) from [Su et al 2022](https://arxiv.org/abs/2104.09864).

    Args:
        max_period (float): Maximum period of the rotation frequencies.
    """

    def __init__(self, max_period: float = 10000.0):
        super().__init__()
        self.max_period = max_period

    def forward(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        offset: torch.Tensor,
        time_before_heads: bool = False,
    ):
        """Apply rope rotation to query and key tensors."""
        return apply_rope(q, k, offset, self.max_period, time_before_heads)
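Note: `apply_rope` rotates each `(even, odd)` channel pair as a complex number by the angle `freqs * t`, so it should leave per-position vector norms unchanged. A small sanity sketch (not from the source), with shapes following the default `[B, H, T, D]` convention of the docstring:

import torch
from moshi.modules.rope import RotaryEmbedding

rope = RotaryEmbedding(max_period=10000.0)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
offset = torch.zeros(1, dtype=torch.long)  # start of the stream
q_rot, k_rot = rope(q, k, offset)
print(q_rot.shape == q.shape)                                         # True
print(torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-4))  # True: rotation preserves norms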
moshi/modules/seanet.py
ADDED
@@ -0,0 +1,395 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

import numpy as np
import torch.nn as nn

from .conv import StreamingConv1d, StreamingConvTranspose1d
from .streaming import StreamingContainer, StreamingAdd
from ..utils.compile import torch_compile_lazy


class SEANetResnetBlock(StreamingContainer):
    """Residual block from SEANet model.

    Args:
        dim (int): Dimension of the input/output.
        kernel_sizes (list): List of kernel sizes for the convolutions.
        dilations (list): List of dilations for the convolutions.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection.
    """

    def __init__(
        self,
        dim: int,
        kernel_sizes: tp.List[int] = [3, 1],
        dilations: tp.List[int] = [1, 1],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        norm: str = "none",
        norm_params: tp.Dict[str, tp.Any] = {},
        causal: bool = False,
        pad_mode: str = "reflect",
        compress: int = 2,
        true_skip: bool = True,
    ):
        super().__init__()
        assert len(kernel_sizes) == len(
            dilations
        ), "Number of kernel sizes should match number of dilations"
        act = getattr(nn, activation)
        hidden = dim // compress
        block = []
        for i, (kernel_size, dilation) in enumerate(zip(kernel_sizes, dilations)):
            in_chs = dim if i == 0 else hidden
            out_chs = dim if i == len(kernel_sizes) - 1 else hidden
            block += [
                act(**activation_params),
                StreamingConv1d(
                    in_chs,
                    out_chs,
                    kernel_size=kernel_size,
                    dilation=dilation,
                    norm=norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                ),
            ]
        self.block = nn.Sequential(*block)
        self.add = StreamingAdd()
        self.shortcut: nn.Module
        if true_skip:
            self.shortcut = nn.Identity()
        else:
            self.shortcut = StreamingConv1d(
                dim,
                dim,
                kernel_size=1,
                norm=norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )

    def forward(self, x):
        u, v = self.shortcut(x), self.block(x)
        return self.add(u, v)


class SEANetEncoder(StreamingContainer):
    """SEANet encoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios. The encoder uses downsampling ratios instead of
            upsampling ratios, hence it will use the ratios in the reverse order to the ones specified here
            that must match the decoder order. We use the decoder order as some models may only employ the decoder.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the encoder, it corresponds to the N first blocks.
        mask_fn (nn.Module): Optional mask function to apply after convolution layers.
        mask_position (int): Position of the mask function, with mask_position == 0 for the first convolution layer,
            mask_position == 1 for the first conv block, etc.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 3,
        ratios: tp.List[int] = [8, 5, 4, 2],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        norm: str = "none",
        norm_params: tp.Dict[str, tp.Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = True,
        compress: int = 2,
        disable_norm_outer_blocks: int = 0,
        mask_fn: tp.Optional[nn.Module] = None,
        mask_position: tp.Optional[int] = None,
    ):
        super().__init__()
        self.channels = channels
        self.dimension = dimension
        self.n_filters = n_filters
        self.ratios = list(reversed(ratios))
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = int(np.prod(self.ratios))
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert (
            self.disable_norm_outer_blocks >= 0
            and self.disable_norm_outer_blocks <= self.n_blocks
        ), (
            "Number of blocks for which to disable norm is invalid."
            " It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
        )

        act = getattr(nn, activation)
        mult = 1
        model: tp.List[nn.Module] = [
            StreamingConv1d(
                channels,
                mult * n_filters,
                kernel_size,
                norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]
        if mask_fn is not None and mask_position == 0:
            model += [mask_fn]
        # Downsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = "none" if self.disable_norm_outer_blocks >= i + 2 else norm
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        norm=block_norm,
                        norm_params=norm_params,
                        activation=activation,
                        activation_params=activation_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]

            # Add downsampling layers
            model += [
                act(**activation_params),
                StreamingConv1d(
                    mult * n_filters,
                    mult * n_filters * 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=block_norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    pad_mode=pad_mode,
                ),
            ]
            mult *= 2
            if mask_fn is not None and mask_position == i + 1:
                model += [mask_fn]

        model += [
            act(**activation_params),
            StreamingConv1d(
                mult * n_filters,
                dimension,
                last_kernel_size,
                norm=(
                    "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
                ),
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]

        self.model = nn.Sequential(*model)

    @torch_compile_lazy
    def forward(self, x):
        return self.model(x)


class SEANetDecoder(StreamingContainer):
    """SEANet decoder.

    Args:
        channels (int): Audio channels.
        dimension (int): Intermediate representation dimension.
        n_filters (int): Base width for the model.
        n_residual_layers (int): nb of residual layers.
        ratios (Sequence[int]): kernel size and stride ratios.
        activation (str): Activation function.
        activation_params (dict): Parameters to provide to the activation function.
        final_activation (str): Final activation function after all convolutions.
        final_activation_params (dict): Parameters to provide to the activation function.
        norm (str): Normalization method.
        norm_params (dict): Parameters to provide to the underlying normalization used along with the convolution.
        kernel_size (int): Kernel size for the initial convolution.
        last_kernel_size (int): Kernel size for the last convolution.
        residual_kernel_size (int): Kernel size for the residual layers.
        dilation_base (int): How much to increase the dilation with each layer.
        causal (bool): Whether to use fully causal convolution.
        pad_mode (str): Padding mode for the convolutions.
        true_skip (bool): Whether to use true skip connection or a simple
            (streamable) convolution as the skip connection in the residual network blocks.
        compress (int): Reduced dimensionality in residual branches (from Demucs v3).
        disable_norm_outer_blocks (int): Number of blocks for which we don't apply norm.
            For the decoder, it corresponds to the N last blocks.
        trim_right_ratio (float): Ratio for trimming at the right of the transposed convolution under the causal setup.
            If equal to 1.0, it means that all the trimming is done at the right.
    """

    def __init__(
        self,
        channels: int = 1,
        dimension: int = 128,
        n_filters: int = 32,
        n_residual_layers: int = 3,
        ratios: tp.List[int] = [8, 5, 4, 2],
        activation: str = "ELU",
        activation_params: dict = {"alpha": 1.0},
        final_activation: tp.Optional[str] = None,
        final_activation_params: tp.Optional[dict] = None,
        norm: str = "none",
        norm_params: tp.Dict[str, tp.Any] = {},
        kernel_size: int = 7,
        last_kernel_size: int = 7,
        residual_kernel_size: int = 3,
        dilation_base: int = 2,
        causal: bool = False,
        pad_mode: str = "reflect",
        true_skip: bool = True,
        compress: int = 2,
        disable_norm_outer_blocks: int = 0,
        trim_right_ratio: float = 1.0,
    ):
        super().__init__()
        self.dimension = dimension
        self.channels = channels
        self.n_filters = n_filters
        self.ratios = ratios
        del ratios
        self.n_residual_layers = n_residual_layers
        self.hop_length = int(np.prod(self.ratios))
        self.n_blocks = len(self.ratios) + 2  # first and last conv + residual blocks
        self.disable_norm_outer_blocks = disable_norm_outer_blocks
        assert (
            self.disable_norm_outer_blocks >= 0
            and self.disable_norm_outer_blocks <= self.n_blocks
        ), (
            "Number of blocks for which to disable norm is invalid."
            " It should be lower or equal to the actual number of blocks in the network and greater or equal to 0."
        )

        act = getattr(nn, activation)
        mult = int(2 ** len(self.ratios))
        model: tp.List[nn.Module] = [
            StreamingConv1d(
                dimension,
                mult * n_filters,
                kernel_size,
                norm=(
                    "none" if self.disable_norm_outer_blocks == self.n_blocks else norm
                ),
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            )
        ]

        # Upsample to raw audio scale
        for i, ratio in enumerate(self.ratios):
            block_norm = (
                "none"
                if self.disable_norm_outer_blocks >= self.n_blocks - (i + 1)
                else norm
            )
            # Add upsampling layers
            model += [
                act(**activation_params),
                StreamingConvTranspose1d(
                    mult * n_filters,
                    mult * n_filters // 2,
                    kernel_size=ratio * 2,
                    stride=ratio,
                    norm=block_norm,
                    norm_kwargs=norm_params,
                    causal=causal,
                    trim_right_ratio=trim_right_ratio,
                ),
            ]
            # Add residual layers
            for j in range(n_residual_layers):
                model += [
                    SEANetResnetBlock(
                        mult * n_filters // 2,
                        kernel_sizes=[residual_kernel_size, 1],
                        dilations=[dilation_base**j, 1],
                        activation=activation,
                        activation_params=activation_params,
                        norm=block_norm,
                        norm_params=norm_params,
                        causal=causal,
                        pad_mode=pad_mode,
                        compress=compress,
                        true_skip=true_skip,
                    )
                ]

            mult //= 2

        # Add final layers
        model += [
            act(**activation_params),
            StreamingConv1d(
                n_filters,
                channels,
                last_kernel_size,
                norm="none" if self.disable_norm_outer_blocks >= 1 else norm,
                norm_kwargs=norm_params,
                causal=causal,
                pad_mode=pad_mode,
            ),
        ]
        # Add optional final activation to decoder (eg. tanh)
        if final_activation is not None:
            final_act = getattr(nn, final_activation)
            final_activation_params = final_activation_params or {}
            model += [final_act(**final_activation_params)]
        self.model = nn.Sequential(*model)

    @torch_compile_lazy
    def forward(self, z):
        y = self.model(z)
        return y
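Note: connecting this back to loaders.py: with `ratios = [8, 6, 5, 4]` the encoder hop length is `8 * 6 * 5 * 4 = 960` samples, so at 24 kHz the encoder emits `24000 / 960 = 25` frames per second (Mimi then resamples these to `FRAME_RATE = 12.5` Hz). A sketch using only a subset of the kwargs, for illustration:

import torch
from moshi.modules.seanet import SEANetEncoder

enc = SEANetEncoder(channels=1, dimension=512, n_filters=64, n_residual_layers=1,
                    ratios=[8, 6, 5, 4], causal=True, pad_mode="constant")
print(enc.hop_length)         # 960
x = torch.randn(1, 1, 24000)  # one second of 24 kHz audio
print(enc(x).shape)           # torch.Size([1, 512, 25])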
moshi/modules/streaming.py
ADDED
@@ -0,0 +1,363 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Streaming module API that should be implemented by all streaming components.
"""

import abc
from contextlib import contextmanager
from dataclasses import dataclass
import itertools
import math
import typing as tp
from torch import nn
import torch


class Resetable(tp.Protocol):
    def reset(self) -> None:
        pass


State = tp.TypeVar("State", bound=Resetable)


class StreamingModule(abc.ABC, nn.Module, tp.Generic[State]):
    """Common API for streaming components.

    Each streaming component has a streaming state, which is just a dict[str, Tensor].
    By convention, the first dim of each tensor must be the batch size.
    Don't use dots in the key names, as this would clash with submodules
    (like in state_dict).

    If `self._is_streaming` is True, the component should use and remember
    the proper state inside `self._streaming_state`.

    To set a streaming component in streaming state, use

        with module.streaming():
            ...

    This will automatically reset the streaming state when exiting the context manager.
    This also automatically propagates to all streaming child modules.

    Some modules might also implement the `StreamingModule.flush` method, although
    this one is trickier, as all parent modules must be StreamingModule and implement
    it as well for it to work properly. See `StreamingSequential` after.
    """

    def __init__(self) -> None:
        super().__init__()
        self._streaming_state: State | None = None
        self._streaming_propagate: bool = True

    @property
    def is_streaming(self):
        return self._streaming_state is not None

    def set_streaming_propagate(self, streaming_propagate: bool):
        self._streaming_propagate = streaming_propagate

    def _apply_named_streaming(self, fn: tp.Any):
        def _handle_module(prefix: str, module: nn.Module, recurse: bool = True):
            propagate = True
            if isinstance(module, StreamingModule):
                if module._streaming_propagate:
                    fn(prefix, module)
                else:
                    propagate = False
            if not recurse:
                return
            if propagate:
                for name, child in module.named_children():
                    _handle_module(prefix + "." + name, child)

        _handle_module("", self, recurse=False)
        for name, child in self.named_children():
            _handle_module(name, child)

    def _start_streaming(self, batch_size: int):
        def _start_streaming(name: str, module: StreamingModule):
            module._streaming_state = module._init_streaming_state(batch_size)

        self._apply_named_streaming(_start_streaming)

    def _stop_streaming(self):
        def _stop_streaming(name: str, module: StreamingModule):
            module._streaming_state = None

        self._apply_named_streaming(_stop_streaming)

    @abc.abstractmethod
    def _init_streaming_state(self, batch_size: int) -> State: ...

    def streaming_forever(self, batch_size: int):
        self._start_streaming(batch_size)

    @contextmanager
    def streaming(self, batch_size: int):
        """Context manager to enter streaming mode. Reset streaming state on exit."""

        self._start_streaming(batch_size)
        try:
            yield
        finally:
            self._stop_streaming()

    def reset_streaming(self):
        """Reset the streaming state."""

        def _reset(name: str, module: StreamingModule):
            state = module._streaming_state
            if state is None:
                raise ValueError(
                    f"Trying to reset streaming, but {name} wasn't streaming."
                )
            state.reset()

        self._apply_named_streaming(_reset)

    def get_streaming_state(self) -> dict[str, tp.Any]:
        """Return the complete streaming state, including that of sub-modules."""
        state: dict[str, tp.Any] = {}

        def _add(name: str, module: StreamingModule):
            state[name] = module._streaming_state

        self._apply_named_streaming(_add)
        return state

    def set_streaming_state(self, state: dict[str, tp.Any]):
        """Set the streaming state, including that of sub-modules."""
        state = dict(state)

        def _set(name: str, module: StreamingModule):
            if name in state:
                module._streaming_state = state[name]
                state.pop(name)
            else:
                raise RuntimeError(f"Expected to find a streaming state for {name}.")

        self._apply_named_streaming(_set)
        if state:
            raise RuntimeError(f"Some states were not consumed: {list(state.keys())}")


@dataclass
class _NullState:
    pass

    def reset(self) -> None:
        pass


class StreamingContainer(StreamingModule[_NullState]):
    def _init_streaming_state(self, batch_size: int) -> _NullState:
        return _NullState()


@dataclass
class _StreamingAddState:
    previous_x: torch.Tensor | None = None
    previous_y: torch.Tensor | None = None

    def reset(self):
        self.previous_x = None
        self.previous_y = None


class StreamingAdd(StreamingModule[_StreamingAddState]):
    def _init_streaming_state(self, batch_size: int) -> _StreamingAddState:
        return _StreamingAddState()

    def forward(self, x: torch.Tensor, y: torch.Tensor):
        if self._streaming_state is None:
            return x + y
        else:
            prev_x = self._streaming_state.previous_x
            prev_y = self._streaming_state.previous_y
            if prev_x is not None:
                x = torch.cat([prev_x, x], dim=-1)
            if prev_y is not None:
                y = torch.cat([prev_y, y], dim=-1)
            m_l = min(x.shape[-1], y.shape[-1])
            self._streaming_state.previous_x = x[..., m_l:]
            self._streaming_state.previous_y = y[..., m_l:]
            return x[..., :m_l] + y[..., :m_l]


@dataclass
class _StreamingConvState:
    previous: torch.Tensor | None = None

    def reset(self):
        self.previous = None


class RawStreamingConv1d(nn.Conv1d, StreamingModule[_StreamingConvState]):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.padding[0] == 0, "Padding should be handled outside."
        assert (
            self.stride[0] <= self.kernel_size[0]
        ), "stride must be less than or equal to kernel_size."

    def _init_streaming_state(self, batch_size: int) -> _StreamingConvState:
        return _StreamingConvState()

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        stride = self.stride[0]
        # Effective kernel size accounting for dilation.
        kernel = (self.kernel_size[0] - 1) * self.dilation[0] + 1
        if self._streaming_state is None:
            return super().forward(input)
        else:
            # Due to the potential overlap, we might have some cache of the previous time steps.
            previous = self._streaming_state.previous
            if previous is not None:
                input = torch.cat([previous, input], dim=-1)
            B, C, T = input.shape
            # We now compute the number of full convolution frames, i.e. the frames
            # that are ready to be computed.
            num_frames = max(0, int(math.floor((T - kernel) / stride) + 1))
            offset = num_frames * stride
            # We will compute `num_frames` outputs, and we are advancing by `stride`
            # for each frame, so we know the data before `stride * num_frames`
            # will never be used again.
            self._streaming_state.previous = input[..., offset:]
            if num_frames > 0:
                input_length = (num_frames - 1) * stride + kernel
                out = super().forward(input[..., :input_length])
            else:
                # Not enough data at this point to output any new frames.
                out = torch.empty(
                    B, self.out_channels, 0, device=input.device, dtype=input.dtype
                )
            return out


@dataclass
class _StreamingConvTrState:
    partial: torch.Tensor | None = None

    def reset(self):
        self.partial = None


class RawStreamingConvTranspose1d(
    nn.ConvTranspose1d, StreamingModule[_StreamingConvTrState]
):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        assert self.padding[0] == 0, "Padding should be handled outside."
        assert self.dilation[0] == 1, "No dilation for now"
        assert (
            self.stride[0] <= self.kernel_size[0]
        ), "stride must be less than or equal to kernel_size."
        assert self.output_padding[0] == 0, "Output padding not supported."

    def _init_streaming_state(self, batch_size: int) -> _StreamingConvTrState:
        return _StreamingConvTrState()

    def forward(self, x: torch.Tensor) -> torch.Tensor:  # type: ignore
        B, C, T = x.shape
        stride = self.stride[0]
        kernel = self.kernel_size[0]
        if self._streaming_state is None:
            return super().forward(x)
        else:
            if T == 0:
                return torch.empty(
                    B, self.out_channels, 0, device=x.device, dtype=x.dtype
                )
            out = super().forward(x)
            OT = out.shape[-1]
            partial = self._streaming_state.partial
            if partial is not None:
                # Due to the potential overlap, the rightmost output of the conv transpose is not
                # ready to be output, as it will receive contributions from the next input frames.
                # Here we recover those `partial` output frames. We know that the first time step
                # of the `partial` tensor corresponds to the first time step of `out` as anything
                # coming before the first time step of `out` would have been already flushed.
                PT = partial.shape[-1]
                if self.bias is not None:
                    out[..., :PT] += partial - self.bias[:, None]
                else:
                    out[..., :PT] += partial
            # The input is T, the output is S * (T - 1) + K.
            # The offset of the left of the next frame will be S * T,
            # so everything between 0 and S * T is ready to be output, and we need
            # to keep in the internal state everything beyond that, i.e. S * (T - 1) + K - S * T = K - S.
            invalid_steps = kernel - stride
            partial = out[..., OT - invalid_steps :]
            out = out[..., : OT - invalid_steps]
            self._streaming_state.partial = partial
            return out


def test():
    torch.manual_seed(1234)
    device = "cpu"
    if torch.cuda.is_available():
        # Avoid the cuda optimizations that would take place on single precision
        # floats for convolutions.
        torch.backends.cudnn.enabled = True
        torch.backends.cudnn.benchmark = False
        torch.backends.cudnn.deterministic = True
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        device = "cuda:0"

    kernel_sizes = [1, 3, 4, 8, 15, 16]
    strides = [1, 2, 3, 4, 5, 6, 7, 8, 9]
    chin = 6
    chout = 12

    for kernel, stride in itertools.product(kernel_sizes, strides):
        if stride > kernel:
            continue
        conv = RawStreamingConv1d(chin, chout, kernel, stride).to(device)
        convtr = RawStreamingConvTranspose1d(chout, chin, kernel, stride).to(device)

        for length in [4, 8, 32, 54, 65, 128, 1043]:
            print(f"ksize {kernel} strides {stride} len {length}")
            if length < kernel:
                continue
            batch_size = 3
            x = torch.randn(batch_size, chin, length).to(device)
            y = conv(x)
            z = convtr(y)
            for chunk_size in [1, 3, 5, 8]:
                ys = []
                zs = []
                with conv.streaming(batch_size), convtr.streaming(batch_size):
                    for offset in range(0, length, chunk_size):
                        chunk = x[..., offset : offset + chunk_size]
                        ys.append(conv(chunk))
                        zs.append(convtr(ys[-1]))
                y_stream = torch.cat(ys, dim=-1)
                z_stream = torch.cat(zs, dim=-1)
                y = y[..., : y_stream.shape[-1]]
                z = z[..., : z_stream.shape[-1]]
                assert y.shape == y_stream.shape, (y.shape, y_stream.shape)
                delta = (y_stream - y).norm() / y.norm()
                assert delta <= 1e-6, delta
                num_frames = int((length - kernel) / stride) + 1
                assert num_frames == y_stream.shape[-1]

                assert z.shape == z_stream.shape, (z.shape, z_stream.shape)
                delta = (z_stream - z).norm() / z.norm()
                assert delta <= 1e-6, (delta, (z_stream - z).abs().mean(dim=(0, 1)))


if __name__ == "__main__":
    with torch.no_grad():
        test()
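Note — a minimal sketch of the streaming contract above (an editorial illustration, not part of the uploaded files); it assumes the module is importable as `moshi.modules.streaming`. An offline run and a chunked streaming run must agree on every frame that is ready, because the internal cache carries the (kernel - stride) overlap between chunks:

import torch
from moshi.modules.streaming import RawStreamingConv1d

conv = RawStreamingConv1d(2, 4, kernel_size=8, stride=4)
x = torch.randn(1, 2, 64)
ref = conv(x)  # offline: all frames at once

chunks = []
with conv.streaming(batch_size=1):
    for t in range(0, 64, 16):
        chunks.append(conv(x[..., t:t + 16]))  # cache carries the overlap
out = torch.cat(chunks, dim=-1)
assert out.shape == ref.shape
assert torch.allclose(ref, out, atol=1e-5)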
moshi/modules/transformer.py
ADDED
@@ -0,0 +1,750 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Transformer model with streaming support, CUDA-graphable and
optimized for inference.

See `StreamingTransformer` for more information.
"""

from contextlib import ExitStack
from dataclasses import dataclass
import typing as tp

from einops import rearrange
import torch
import torch.nn as nn
from torch.nn import functional as F

from ..utils.compile import no_compile
from .gating import make_gating
from .rope import RotaryEmbedding
from .streaming import StreamingModule, StreamingContainer


class LayerNormF32(nn.LayerNorm):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        x_f32 = input.float()
        out_f32 = super().forward(x_f32)
        return out_f32.to(input.dtype)


def _rms_norm(
    x: torch.Tensor,
    alpha: torch.Tensor,
    dtype: tp.Optional[torch.dtype],
    eps: float,
):
    assert x.dim() == 3, f"RMSNorm expects 3D inputs but got {x.shape}"
    x_dtype = x.dtype
    if dtype is not None:
        x = x.to(dtype)
    var = eps + torch.mean(x**2, dim=2, keepdim=True)
    y = (x * (alpha.to(var) * torch.rsqrt(var))).to(x_dtype)
    return y


class RMSNorm(nn.Module):
    def __init__(
        self,
        dim: int,
        eps: float = 1e-5,
        dtype: tp.Optional[torch.dtype] = None,
        device=None,
    ):
        super().__init__()
        self.eps = eps
        self.dtype = dtype
        self.alpha = nn.Parameter(
            torch.full((1, 1, dim), 1.0, requires_grad=True, device=device, dtype=dtype)
        )

    def forward(self, x: torch.Tensor):
        return _rms_norm(x, self.alpha, self.dtype, self.eps)


class LayerScale(nn.Module):
    """Layer scale from [Touvron et al 2021] (https://arxiv.org/pdf/2103.17239.pdf).
    This rescales diagonally the residual outputs close to 0, with a learnt scale.

    Args:
        channels (int): Number of channels.
        init (float): Initial scale.
        channel_last (bool): If True, expect `[*, C]` shaped tensors, otherwise, `[*, C, T]`.
        device (torch.device or str, optional): Device on which to initialize the module.
        dtype (torch.dtype, optional): dtype to use to initialize the module.
    """

    def __init__(
        self,
        channels: int,
        init: float = 1e-4,
        channel_last: bool = True,
        device=None,
        dtype=None,
    ):
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(
            torch.full(
                (channels,), init, requires_grad=True, device=device, dtype=dtype
            )
        )

    def forward(self, x: torch.Tensor):
        if self.channel_last:
            return self.scale * x
        else:
            return self.scale[:, None] * x


def create_norm_fn(norm_type: str, dim: int, **kwargs) -> nn.Module:
    """Create normalization module for transformer encoder layer.

    Args:
        norm_type (str): Normalization method.
        dim (int): Dimension of the normalized layer.
        **kwargs (dict): Additional parameters for normalization layer.
    Returns:
        nn.Module: Normalization module.
    """
    if norm_type == "layer_norm":
        return nn.LayerNorm(dim, eps=1e-5, **kwargs)
    elif norm_type == "layer_norm_f32":
        kwargs.pop("dtype", None)
        return LayerNormF32(dim, eps=1e-8, **kwargs)
    elif norm_type in {"rms_norm"}:
        return RMSNorm(dim, eps=1e-5, **kwargs)
    elif norm_type in {"rms_norm_f32"}:
        kwargs.pop("dtype", None)
        return RMSNorm(dim, eps=1e-8, dtype=torch.float, **kwargs)
    else:
        raise ValueError(f"Unknown norm type: {norm_type}")


def create_sin_embedding(
    positions: torch.Tensor,
    dim: int,
    max_period: float = 10000,
    dtype: torch.dtype = torch.float32,
) -> torch.Tensor:
    """Create sinusoidal positional embedding, with shape `[B, T, C]`.

    Args:
        positions (torch.Tensor): LongTensor of positions.
        dim (int): Dimension of the embedding.
        max_period (float): Maximum period of the cosine/sine functions.
        dtype (torch.dtype or str): dtype to use to generate the embedding.
    Returns:
        torch.Tensor: Sinusoidal positional embedding.
    """
    # We aim for BTC format
    assert dim % 2 == 0
    half_dim = dim // 2
    positions = positions.to(dtype)
    adim = torch.arange(half_dim, device=positions.device, dtype=dtype).view(1, 1, -1)
    max_period_tensor = torch.full(
        [], max_period, device=positions.device, dtype=dtype
    )  # avoid sync point
    phase = positions / (max_period_tensor ** (adim / (half_dim - 1)))
    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)


def multi_linear(
    num_linear: int,
    weight: torch.Tensor,
    x: torch.Tensor,
    offset: int,
):
    """Utility to apply a multi linear layer to the given input. A multi linear layer
    applies a different set of weights for each time step.

    Args:
        num_linear (int): Number of possible time steps and so number of linears.
        weight (torch.Tensor): Weight tensor, with shape `[num_linear * chout, chin]`.
        x (torch.Tensor): Input tensor, with shape `[B, T, C]`.
        offset (int): offset for the current time step, in particular for decoding, with
            time steps provided one by one.
    """
    B, T, C = x.shape
    ys = []
    chout, chin = weight.shape
    weight = weight.view(num_linear, -1, chin)
    for t in range(T):
        y = F.linear(x[:, t], weight[t + offset])
        ys.append(y)
    out = torch.stack(ys, 1)
    return out


def set_attention_context(model: nn.Module, context: tp.Optional[int] = None) -> None:
    """Deactivates or changes the context span (in time steps) in a model.
    Args:
        model (nn.Module): model over which to look for attentions.
        context (int or None): new temporary context value.

    ..Note:: this is not a context manager but a plain function changing the context forever.
        Initially, it was a context manager, but that led to interesting bugs when using
        activation checkpointing, with the context being inconsistent between the forward
        and backward.
    """
    for module in model.modules():
        if isinstance(module, StreamingMultiheadAttention):
            module.context = context


class KVCacheResult(tp.NamedTuple):
    keys: torch.Tensor
    values: torch.Tensor
    positions: torch.Tensor

    @staticmethod
    def from_kv(keys: torch.Tensor, values: torch.Tensor) -> "KVCacheResult":
        B, H, T, D = keys.shape
        assert tuple(values.shape[:-1]) == (B, H, T)
        positions = torch.arange(T, device=keys.device, dtype=torch.long)
        return KVCacheResult(keys, values, positions)


class RingKVCache:
    """Efficient streaming KVCache to be compatible with CUDA Graph.

    Args:
        batch_size (int): Batch size.
        num_heads (int): Number of heads in the attention.
        dim_per_head (int): Dimension per head.
        capacity (int): Number of time steps the ring buffer can hold.
        device (torch.device): Device on which to initialize the cache.
        dtype (torch.dtype): dtype to use for the cache.
    """

    def __init__(
        self,
        batch_size: int,
        num_heads: int,
        dim_per_head: int,
        capacity: int,
        device: torch.device = torch.device("cuda"),
        dtype: torch.dtype = torch.bfloat16,
    ):
        self.capacity = capacity
        self.cache = torch.zeros(
            (2, batch_size, num_heads, capacity, dim_per_head),
            device=device,
            dtype=dtype,
        )
        self.end_offset = torch.zeros(1, device=device, dtype=torch.long)

    def reset(self):
        self.end_offset.zero_()

    def complete(self, k: torch.Tensor, v: torch.Tensor) -> KVCacheResult:
        assert k.shape[:-1] == v.shape[:-1], (k.shape, v.shape)
        B, H, T, D = k.shape
        indexes = torch.arange(T, device=self.end_offset.device, dtype=self.end_offset.dtype) + self.end_offset
        indexes = indexes % self.capacity
        self.cache[0].index_copy_(2, indexes, k)
        self.cache[1].index_copy_(2, indexes, v)
        self.end_offset.add_(T)

        keys = self.cache[0]
        values = self.cache[1]

        indexes = torch.arange(
            self.capacity, device=self.end_offset.device, dtype=torch.long
        )
        invalid = indexes >= self.end_offset

        end_index = self.end_offset % self.capacity
        delta = indexes - end_index

        # If the last key written was for step S with capacity C, it went to index
        # S % C, so end_offset = S + 1 and end_index = (S + 1) % C.
        # For an index with delta <= 0, the entry was written during the current lap
        # of the ring buffer, e.g. for index = S % C, delta = -1 and
        # position(index) = end_offset - 1 = S, all good.
        # For an index with delta > 0, the entry survives from the previous lap,
        # e.g. for index = end_index + 1, delta = 1 and
        # position(index) = end_offset + 1 - C.
        # (The slot at end_index itself maps to end_offset, which the causal mask
        # rejects until that slot is overwritten.)

        positions = torch.where(
            delta <= 0,
            self.end_offset + delta,
            self.end_offset + delta - self.capacity,
        )
        positions = torch.where(invalid, torch.full_like(positions, -1), positions)

        return KVCacheResult(keys, values, positions)


@dataclass
class _MHAState:
    kv_cache: RingKVCache
    offset: torch.Tensor
    offset_cpu: int

    def reset(self):
        self.kv_cache.reset()
        self.offset.zero_()
        self.offset_cpu = 0


class StreamingMultiheadAttention(StreamingModule[_MHAState]):
    """Similar to `nn.MultiheadAttention` but with support for streaming, causal evaluation.

    Args:
        embed_dim (int): Dimension to project to.
        num_heads (int): Number of heads.
        causal (bool): Causal mask applied automatically.
        context (int, optional): Number of time steps the attention can access.
            When causal, can access `context` time steps into the past, and when non causal,
            can access `context // 2` steps in the past, and the same in the future.
        rope (`RotaryEmbedding`, optional): Rope embedding to use.
        weights_per_step (int): use different weights per time step. If non zero, should correspond to the
            number of possible time steps.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
    """

    _fsdp_final = True

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        causal: bool = False,
        context: tp.Optional[int] = None,
        rope: tp.Optional[RotaryEmbedding] = None,
        weights_per_step: int = 0,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}

        self.embed_dim = embed_dim
        self.causal = causal
        self.context = context
        self.rope = rope
        self.num_heads = num_heads

        out_dim = 3 * embed_dim
        mult = 1
        self.weights_per_step = weights_per_step
        if weights_per_step:
            mult = weights_per_step
        in_proj = nn.Linear(embed_dim, mult * out_dim, bias=False, **factory_kwargs)
        # We try to follow the default PyTorch MHA convention, to easily compare results.
        self.in_proj_weight = in_proj.weight
        self.in_proj_bias = in_proj.bias
        self.out_proj = nn.Linear(
            embed_dim, mult * embed_dim, bias=False, **factory_kwargs
        )

    def _init_streaming_state(self, batch_size: int) -> _MHAState:
        if self.context is None:
            if self.weights_per_step:
                capacity = self.weights_per_step
            else:
                raise RuntimeError(
                    "Cannot create a streaming KVCache without a context to estimate capacity."
                )
        else:
            capacity = self.context
        device = self.in_proj_weight.device
        # TODO: the following estimation will not work great with FSDP.
        dtype = self.in_proj_weight.dtype
        dim_per_head = self.embed_dim // self.num_heads
        kv_cache = RingKVCache(
            batch_size, self.num_heads, dim_per_head, capacity, device, dtype
        )
        return _MHAState(
            kv_cache,
            offset=torch.zeros(1, device=device, dtype=torch.long),
            offset_cpu=0,
        )

    def _complete_kv(self, k, v) -> KVCacheResult:
        state = self._streaming_state
        if state is None:
            return KVCacheResult.from_kv(k, v)
        else:
            return state.kv_cache.complete(k, v)

    def forward(self, query: torch.Tensor, key: torch.Tensor, value: torch.Tensor):
        state = self._streaming_state
        T = query.shape[1]

        if state is None:
            offset = torch.zeros(1, device=query.device, dtype=torch.long)
            offset_cpu = 0
        else:
            assert self.causal, "Streaming only available for causal"
            offset = state.offset
            offset_cpu = state.offset_cpu

        if self.weights_per_step:
            projected = multi_linear(
                self.weights_per_step, self.in_proj_weight, query, offset_cpu
            )
        else:
            projected = nn.functional.linear(query, self.in_proj_weight)
        q, k, v = rearrange(
            projected, "b t (p h d) -> p b h t d", p=3, h=self.num_heads
        )

        if self.rope:
            q, k = self.rope(q, k, offset, time_before_heads=False)

        k, v, pos_k = self._complete_kv(k, v)
        if self.causal:
            pos_k = pos_k.view(1, -1)
            pos_q = offset + torch.arange(T, device=q.device, dtype=torch.long).view(
                -1, 1
            )
            delta = pos_q - pos_k
            attn_bias = (pos_k >= 0) & (delta >= 0)
            if self.context is not None:
                attn_bias = attn_bias & (delta < self.context)
        else:
            attn_bias = None
        x = F.scaled_dot_product_attention(q, k, v, attn_bias, dropout_p=0.0)

        x = rearrange(x, "b h t d -> b t (h d)")
        if self.weights_per_step:
            x = multi_linear(self.weights_per_step, self.out_proj.weight, x, offset_cpu)
        else:
            x = self.out_proj(x)
        if state is not None:
            state.offset.add_(T)
            state.offset_cpu += T
        return x


@dataclass
class _LayerState:
    offset_cpu: int

    def reset(self):
        self.offset_cpu = 0


class StreamingTransformerLayer(StreamingModule[_LayerState]):
    """TransformerLayer with Streaming / Causal support.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        dim_feedforward (int): Intermediate dimension of FF module.
        causal (bool): Causal mask applied automatically.
        context (int, optional): Receptive field for the causal mask, infinite if None.
        rope (`RotaryEmbedding`, optional): Rope embedding to use.
        norm (str): Normalization to use. See `create_norm_fn` for the supported values.
        layer_scale (float, optional): If not None, LayerScale will be used with the given value as initial scale.
        gating (str): if provided, replaces FFN with special gating, like GLU, GSiGLU etc.
        weights_per_step (int): use different weights per time step. If non zero, should correspond to the
            number of possible time steps.
        skip_self_attn (bool): If True, skips the self-attention module and its norm.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
    """

    _fsdp_final = True

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dim_feedforward: int | list[int] = 2048,
        causal: bool = False,
        context: tp.Optional[int] = None,
        rope: tp.Optional[RotaryEmbedding] = None,
        norm: str = "layer_norm",
        layer_scale: tp.Optional[float] = None,
        gating: str = "none",
        weights_per_step: int = 0,
        activation=F.gelu,
        skip_self_attn: bool = False,
        device=None,
        dtype=None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        # Redefine self_attn to our streaming multi-head attention
        attn_kwargs: tp.Dict[str, tp.Any] = {
            "embed_dim": d_model,
            "num_heads": num_heads,
        }
        if not skip_self_attn:
            self.self_attn: StreamingMultiheadAttention = StreamingMultiheadAttention(
                causal=causal,
                context=context,
                rope=rope,
                weights_per_step=weights_per_step,
                **attn_kwargs,  # type: ignore
                **factory_kwargs,  # type: ignore
            )  # type: ignore
            self.norm1 = create_norm_fn(norm, d_model, **factory_kwargs)
        self.norm2 = create_norm_fn(norm, d_model, **factory_kwargs)
        # Redefine feedforward layers to expose bias parameter
        self.weights_per_step = weights_per_step
        self.gating: tp.Optional[nn.Module] = None
        self.linear1: tp.Optional[nn.Module] = None
        self.linear2: tp.Optional[nn.Module] = None
        self.activation = activation
        self.skip_self_attn = skip_self_attn

        if isinstance(dim_feedforward, list):
            assert dim_feedforward
            assert len(dim_feedforward) == weights_per_step, (
                "Length of dim_feedforward must match weights_per_step,"
                f" got {len(dim_feedforward)} != {weights_per_step}"
            )
        if gating == "none":
            assert (
                not weights_per_step
            ), "weights_per_step without gating not supported for now."
            assert not isinstance(
                dim_feedforward, list
            ), "List dim_feedforward without gating not supported for now."
            self.linear1 = nn.Linear(
                d_model, dim_feedforward, bias=False, **factory_kwargs
            )
            self.linear2 = nn.Linear(
                dim_feedforward, d_model, bias=False, **factory_kwargs
            )
        else:
            self.linear1 = None
            self.linear2 = None
            if weights_per_step:
                if isinstance(dim_feedforward, int):
                    dim_feedforward = [dim_feedforward] * weights_per_step
                assert isinstance(dim_feedforward, list), dim_feedforward
                self.gating = nn.ModuleList(
                    [
                        make_gating(gating, d_model, dim, **factory_kwargs)
                        for dim in dim_feedforward
                    ]
                )
            else:
                assert isinstance(dim_feedforward, int)
                self.gating = make_gating(
                    gating, d_model, dim_feedforward, **factory_kwargs
                )

        self.layer_scale_1: nn.Module
        self.layer_scale_2: nn.Module
        if layer_scale is None:
            self.layer_scale_1 = nn.Identity()
            self.layer_scale_2 = nn.Identity()
        else:
            self.layer_scale_1 = LayerScale(d_model, layer_scale, **factory_kwargs)  # type: ignore
            self.layer_scale_2 = LayerScale(d_model, layer_scale, **factory_kwargs)  # type: ignore

    def _init_streaming_state(self, batch_size: int) -> _LayerState:
        return _LayerState(offset_cpu=0)

    # feed forward block
    def _ff_block(self, x: torch.Tensor) -> torch.Tensor:
        state = self._streaming_state
        offset = 0
        if state is not None:
            offset = state.offset_cpu
        x_orig = x
        x = self.norm2(x)
        if self.gating is None:
            assert self.linear1 is not None
            assert self.linear2 is not None
            update = self.linear2(self.activation(self.linear1(x)))
        else:
            if self.weights_per_step:
                assert isinstance(self.gating, nn.ModuleList)
                B, T, D = x.shape
                ys = []
                for t in range(T):
                    y = self.gating[offset + t](x[:, t : t + 1])
                    ys.append(y)
                update = torch.cat(ys, dim=1)
            else:
                update = self.gating(x)
        return x_orig + self.layer_scale_2(update)

    def _sa_block(self, x: torch.Tensor):
        if self.skip_self_attn:
            return x
        x_orig = x
        x = self.norm1(x)
        update = self.self_attn(x, x, x)
        return x_orig + self.layer_scale_1(update)

    def forward(self, x: torch.Tensor):
        with ExitStack() as stack:
            if x.device.type != 'cuda':
                stack.enter_context(no_compile())
            x = self._sa_block(x)
            x = self._ff_block(x)
            state = self._streaming_state
            if state:
                state.offset_cpu += x.shape[1]
            return x


@dataclass
class _TransformerState:
    offset: torch.Tensor

    def reset(self):
        self.offset.zero_()


class StreamingTransformer(StreamingModule[_TransformerState]):
    """Transformer with Streaming / Causal support.

    Args:
        d_model (int): Dimension of the data.
        num_heads (int): Number of heads.
        dim_feedforward (int): Intermediate dimension of FF module.
        causal (bool): Causal mask applied automatically.
        context (int, optional): Receptive field for the causal mask, infinite if None.
        layer_scale (float, optional): If not None, LayerScale will be used
            with the given value as initial scale.
        positional_embedding (str): Positional embedding strategy (sin, rope, sin_rope, or none).
        max_period (float): Maximum period of the time embedding.
        positional_scale (float): Scale of positional embedding, set to 0 to deactivate.
        layer_class (subclass of `StreamingTransformerLayer`): class to use
            to initialize the layers, allowing further customization outside of AudioCraft.
        device (torch.device, optional): Device on which to initialize.
        dtype (torch.dtype, optional): dtype to use.
        **kwargs: See `StreamingTransformerLayer`.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        num_layers: int,
        dim_feedforward: int | list[int] = 2048,
        causal: bool = False,
        context: tp.Optional[int] = None,
        positional_embedding: str = "sin",
        max_period: float = 10_000,
        positional_scale: float = 1.0,
        betas: tp.Optional[tp.Tuple[float, float]] = None,
        layer_class: tp.Type[StreamingTransformerLayer] = StreamingTransformerLayer,
        device=None,
        dtype=None,
        **kwargs,
    ):
        super().__init__()
        assert d_model % num_heads == 0

        self.positional_embedding = positional_embedding
        self.max_period = max_period
        self.positional_scale = positional_scale
        self.betas = betas

        assert positional_embedding in {"sin", "rope", "sin_rope", "none"}
        self.rope: tp.Optional[RotaryEmbedding] = None
        if self.positional_embedding in {"rope", "sin_rope"}:
            self.rope = RotaryEmbedding(max_period=max_period)

        self.layers = nn.ModuleList()
        for _ in range(num_layers):
            self.layers.append(
                layer_class(
                    d_model=d_model,
                    num_heads=num_heads,
                    dim_feedforward=dim_feedforward,
                    causal=causal,
                    context=context,
                    rope=self.rope,
                    device=device,
                    dtype=dtype,
                    **kwargs,
                )
            )

    def _init_streaming_state(self, batch_size: int) -> _TransformerState:
        device = next(self.parameters()).device
        return _TransformerState(offset=torch.zeros(1, device=device, dtype=torch.long))

    def forward(self, x: torch.Tensor, *args, **kwargs):
        B, T, C = x.shape

        state = self._streaming_state
        if state is None:
            offset = torch.zeros(1, dtype=torch.long, device=x.device)
        else:
            offset = state.offset

        if self.positional_embedding in {"sin", "sin_rope"}:
            positions = torch.arange(T, device=x.device).view(1, -1, 1)
            positions = positions + offset.view(-1, 1, 1)
            pos_emb = create_sin_embedding(
                positions, C, max_period=self.max_period, dtype=x.dtype
            )
            x = x + self.positional_scale * pos_emb

        for layer in self.layers:
            x = layer(x, *args, **kwargs)

        if state is not None:
            state.offset.add_(T)
        return x


class ProjectedTransformer(StreamingContainer):
    """Transformer with optional projections of the input and output to different dimensions when needed.
    Supports multiple outputs.

    Args:
        input_dimension (int): dimension of the input.
        output_dimensions (tuple[int]): dimensions of the outputs.
        d_model (int): inner dimension of the Transformer.
        conv_layout (bool): If True, expects `[B, C, T]` shaped tensors, otherwise, `[B, T, C]`.
            Similarly, the output will have the same layout.
    """

    def __init__(
        self,
        input_dimension: int,
        output_dimensions: tp.Tuple[int, ...],
        d_model: int,
        *,
        conv_layout: bool = False,
        **kwargs,
    ):
        super().__init__()
        self.transformer = StreamingTransformer(d_model=d_model, **kwargs)
        self.input_dimension = input_dimension
        self.output_dimensions = output_dimensions
        self.conv_layout = conv_layout
        self.input_proj = None
        if d_model != input_dimension:
            self.input_proj = nn.Linear(input_dimension, d_model, bias=False)

        self.output_projs = nn.ModuleList()
        for output_dimension in output_dimensions:
            if d_model == output_dimension:
                self.output_projs.append(nn.Identity())
            else:
                self.output_projs.append(
                    nn.Linear(d_model, output_dimension, bias=False)
                )

    def forward(self, x, *args, **kwargs):
        if self.conv_layout:
            x = x.transpose(1, 2)
        if self.input_proj is not None:
            x = self.input_proj(x)
        z = self.transformer(x, *args, **kwargs)
        ys = []
        for output_proj in self.output_projs:
            y = output_proj(z)
            if self.conv_layout:
                y = y.transpose(1, 2)
            ys.append(y)
        return ys
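Note — a minimal sketch of the streaming transformer path above (an editorial illustration, not part of the uploaded files); it assumes the module is importable as `moshi.modules.transformer` and a CPU run with the default layer settings. Offline and chunked streaming passes should agree, since the per-layer RingKVCache replays the same causal attention pattern:

import torch
from moshi.modules.transformer import StreamingTransformer

model = StreamingTransformer(
    d_model=64, num_heads=4, num_layers=2, dim_feedforward=128,
    causal=True, context=32, positional_embedding="sin",
)
model.eval()
x = torch.randn(1, 16, 64)
with torch.no_grad():
    ref = model(x)  # offline: causal mask over the full sequence
    with model.streaming(batch_size=1):
        chunks = [model(x[:, t:t + 4]) for t in range(0, 16, 4)]
out = torch.cat(chunks, dim=1)
print((ref - out).abs().max())  # expected to be on the order of 1e-6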
moshi/quantization/__init__.py
ADDED
@@ -0,0 +1,13 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""RVQ."""
# flake8: noqa
from .vq import ResidualVectorQuantizer, SplitResidualVectorQuantizer
from .base import BaseQuantizer, DummyQuantizer, QuantizedResult
moshi/quantization/base.py
ADDED
@@ -0,0 +1,170 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Base class for all quantizers.
"""

from dataclasses import dataclass, field
import typing as tp

import torch
from torch import nn


@dataclass
class QuantizedResult:
    x: torch.Tensor
    codes: torch.Tensor
    bandwidth: torch.Tensor  # bandwidth in kb/s used, per batch item.
    penalty: tp.Optional[torch.Tensor] = None
    metrics: dict = field(default_factory=dict)


class BaseQuantizer(nn.Module):
    """Base class for quantizers."""

    def __init__(self):
        super().__init__()
        self._ema_frozen = False

    def forward(self, x: torch.Tensor, frame_rate: int) -> QuantizedResult:
        """
        Given input tensor x, returns first the quantized (or approximately quantized)
        representation along with quantized codes, bandwidth, and any penalty term for the loss.
        Finally, this returns a dict of metrics to update logging etc.
        Frame rate must be passed so that the bandwidth is properly computed.
        """
        raise NotImplementedError()

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth."""
        raise NotImplementedError()

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        raise NotImplementedError()

    @property
    def cardinality(self) -> int:
        """Cardinality of each codebook."""
        raise NotImplementedError()

    @property
    def total_codebooks(self) -> int:
        """Total number of codebooks."""
        raise NotImplementedError()

    @property
    def num_codebooks(self) -> int:
        """Number of active codebooks."""
        raise NotImplementedError()

    @property
    def semantic_quantizer(self) -> 'BaseQuantizer':
        """This returns the quantizer that models the first level of the hierarchy (typically semantic).

        In this case, it's the quantizer itself.
        """
        return self

    @property
    def acoustic_quantizer(self) -> 'BaseQuantizer':
        """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic).

        In this case, it's the quantizer itself.
        """
        return self

    def set_num_codebooks(self, n: int) -> None:
        """Set the number of active codebooks."""
        raise NotImplementedError()

    @property
    def ema_frozen(self) -> bool:
        """Whether to apply ema to the codebooks."""
        return self._ema_frozen

    def ema_frozen_(self, ema_frozen: bool) -> None:
        """Set whether ema should be applied to the codebooks."""
        self._ema_frozen = ema_frozen


class DummyQuantizer(BaseQuantizer):
    """Fake quantizer that actually does not perform any quantization."""

    def __init__(
        self,
        dimension: int,
        input_dimension: tp.Optional[int] = None,
        output_dimension: tp.Optional[int] = None,
    ):
        super().__init__()
        self.dimension = dimension
        self.input_dimension = input_dimension or dimension
        self.output_dimension = output_dimension or dimension
        self.input_proj: torch.nn.Module
        self.output_proj: torch.nn.Module
        if self.input_dimension == self.dimension:
            self.input_proj = torch.nn.Identity()
        else:
            self.input_proj = torch.nn.Conv1d(
                self.input_dimension, self.dimension, 1, bias=False
            )
        if self.output_dimension == self.dimension:
            self.output_proj = torch.nn.Identity()
        else:
            self.output_proj = torch.nn.Conv1d(
                self.dimension, self.output_dimension, 1, bias=False
            )

    def forward(self, x: torch.Tensor, frame_rate: int):
        q = x.unsqueeze(1)
        x = self.output_proj(self.input_proj(x))
        return QuantizedResult(
            x, q, torch.tensor(q.numel() * 32 * frame_rate / 1000 / len(x)).to(x)
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified sample rate at the given bandwidth.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        x = self.input_proj(x)
        return x.unsqueeze(1)

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation.
        In the case of the DummyQuantizer, the codes are actually identical
        to the input and resulting quantized representation as no quantization is done.
        """
        y = codes.squeeze(1)
        return self.output_proj(y)

    @property
    def total_codebooks(self):
        """Total number of codebooks."""
        return 1

    @property
    def num_codebooks(self):
        """Number of active codebooks."""
        return self.total_codebooks

    def set_num_codebooks(self, n: int):
        """Set the number of active codebooks."""
        raise AttributeError(
            "Cannot override the number of codebooks for the dummy quantizer"
        )

    @property
    def cardinality(self) -> int:
        """Cardinality of each codebook."""
        return 1
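Note — a minimal sketch of the quantizer interface above (an editorial illustration, not part of the uploaded files); it assumes the package is importable as `moshi.quantization`. The DummyQuantizer round-trips its input unchanged (up to the optional projections), which makes it handy for ablating quantization while keeping the rest of the pipeline intact:

import torch
from moshi.quantization import DummyQuantizer

q = DummyQuantizer(dimension=8)
x = torch.randn(2, 8, 50)      # [B, C, T] frames
codes = q.encode(x)            # [B, 1, C, T]: the "codes" are the input itself
y = q.decode(codes)
assert torch.equal(x, y)
res = q(x, frame_rate=25)      # QuantizedResult with bandwidth bookkeeping
print(res.bandwidth)           # kb/s per batch item for float32 "codes"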
moshi/quantization/core_vq.py
ADDED
@@ -0,0 +1,384 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import typing as tp

from einops import rearrange
import torch
from torch import nn
from torch import distributed
import torch.nn.functional as F


class _CodebookForwardResult(tp.NamedTuple):
    quantized: torch.Tensor
    codes: torch.Tensor
    metrics: tp.Dict[str, torch.Tensor]


class _VQForwardResult(tp.NamedTuple):
    quantized: torch.Tensor
    codes: torch.Tensor
    loss: torch.Tensor
    metrics: tp.Dict[str, torch.Tensor]


def _ema_inplace(moving_avg: torch.Tensor, new: torch.Tensor, decay: float) -> None:
    moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay))


def _uniform_init(*shape: int) -> torch.Tensor:
    t = torch.empty(shape)
    nn.init.kaiming_uniform_(t)
    return t


def _sample_vectors(samples: torch.Tensor, num: int) -> torch.Tensor:
    num_samples, device = samples.shape[0], samples.device

    if num_samples >= num:
        indices = torch.randperm(num_samples, device=device)[:num]
    else:
        indices = torch.randint(0, num_samples, (num,), device=device)

    return samples[indices]


def _compute_entropy(usage: torch.Tensor) -> torch.Tensor:
    # Usage is some unnormalized distribution.
    proba = usage / usage.sum()
    p_log_p = torch.where(
        proba == 0, zero_scalar(usage.device), proba * torch.log(proba)
    )
    return -p_log_p.sum()


def _is_distributed() -> bool:
    # Checks if we need to use distributed routines.
    return distributed.is_initialized() and distributed.get_world_size() > 1


def zero_scalar(device) -> torch.Tensor:
    """Returns a 0. value on the given device without introducing a synchronization point."""
    return torch.zeros([1], device=device)[0]


class EuclideanCodebook(nn.Module):
    """Codebook with Euclidean distance.

    Args:
        dim (int): Dimension.
        codebook_size (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
            is replaced. This is expressed as a fraction of the usage a centroid would get under
            a uniform distribution, so that it doesn't depend on the batch size etc.
        replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
            to avoid the centroid getting replaced too quickly.
        check_unused_every (int): Check for unused centroids every `check_unused_every` iterations.
            This is to avoid too many synchronization points.

    Buffers:
        cluster_usage (torch.Tensor): EMA of the cluster usage per batch, e.g. this will
            be dependent on the batch size etc.
        embedding_sum (torch.Tensor): EMA of the sum of the assigned points to each cluster.
            In particular, this can be normalized by `cluster_usage` to obtain the
            actual cluster centroids.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_usage_ratio: float = 0.1,
        replaced_usage_ratio: float = 1.0,
        check_unused_every: int = 5,
    ):
        super().__init__()
        self.decay = decay
        embedding = torch.zeros(codebook_size, dim)

        self.dim = dim
        self.codebook_size = codebook_size

        self.epsilon = epsilon
        self.threshold_usage_ratio = threshold_usage_ratio
        self.replaced_usage_ratio = replaced_usage_ratio
        self.check_unused_every = check_unused_every
        self._next_unused_check = check_unused_every

        self.register_buffer("_initialized", torch.tensor([False], dtype=torch.float))
        self.register_buffer("cluster_usage", torch.ones(codebook_size))
        self.register_buffer("embedding_sum", embedding)
        self.register_buffer("_embedding", None, persistent=False)
        self._cached_initialized = False

    def _load_from_state_dict(self, state_dict, prefix, *args, **kwargs) -> None:
        # Mapping old names to new names
        mappings = {
            "inited": "_initialized",
            "cluster_size": "cluster_usage",
            "embed_avg": "embedding_sum",
            "embed_sum": "embedding_sum",
        }
        for old_name, new_name in mappings.items():
            old_name = prefix + old_name
            if old_name in state_dict:
                value = state_dict.pop(old_name)
                if new_name is not None:
                    state_dict[prefix + new_name] = value
        super()._load_from_state_dict(state_dict, prefix, *args, **kwargs)

    @property
    def embedding(self) -> torch.Tensor:
        if self._embedding is None:
            embedding = (
                self.embedding_sum / self.cluster_usage.clamp(min=self.epsilon)[:, None]
            )
            self.register_buffer("_embedding", embedding, persistent=False)
            return embedding
        return self._embedding

    def _broadcast_buffers(self) -> None:
        if _is_distributed():
            for buffer in self.buffers():
                distributed.broadcast(buffer, 0)

    def _replace_expired_codes(self, samples: torch.Tensor, mask: torch.Tensor) -> None:
        # Replaces expired centroids, as indicated by `mask` (a true value indicates the code needs to be replaced).
        # The new codes are sampled from the batch `samples`.
        new_vectors = _sample_vectors(samples, self.codebook_size)
        replace_cluster_usage = (
            self.replaced_usage_ratio * self.cluster_usage.sum() / self.codebook_size
        )
        self.embedding_sum[:] = torch.where(
            mask[:, None], replace_cluster_usage * new_vectors, self.embedding_sum
        )
        self.cluster_usage[:] = torch.where(
            mask, replace_cluster_usage, self.cluster_usage
        )

    def _reshape_input(self, x: torch.Tensor) -> torch.Tensor:
        # Flattens all the dimensions but the last one, e.g. returns a tensor of shape `[N, D]`.
        x = rearrange(x, "... d -> (...) d")
        return x

    def _reshape_codes(self, codes: torch.Tensor, shape: torch.Size) -> torch.Tensor:
        return codes.view(*shape[:-1])

    def _quantize(self, x: torch.Tensor) -> torch.Tensor:
        # Projects each vector in `x` over the nearest centroid and returns its index.
        # `x` should be `[N, D]` with `N` the number of input vectors and `D` the dimension.
        assert x.dim() == 2
        dists = torch.cdist(x[None], self.embedding[None], p=2)[0]
        codes = dists.argmin(dim=-1)
        return codes

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Given a tensor `x` of shape `[*, D]`, returns a tensor of integer codes of shape `[*]`.
        The codes are defined as the indexes of the centroids nearest to each vector in `x`.
        """
        assert x.dtype.is_floating_point, f"Input should be floats, got {x.dtype}"
        shape = x.shape
        x = self._reshape_input(x)
        codes = self._quantize(x)
        codes = self._reshape_codes(codes, shape)
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Given a tensor of codes of shape `[*]`, returns a tensor of shape `[*, D]`,
        corresponding to the centroids associated to each code index.
        """
        assert (
            not codes.dtype.is_floating_point
        ), f"Codes should be integers, got {codes.dtype}"
        quantized = F.embedding(codes, self.embedding)
        return quantized

    def forward(
        self, x: torch.Tensor, initialize: bool = True
    ) -> _CodebookForwardResult:
        shape = x.shape
        x = self._reshape_input(x)

        flat_codes = self._quantize(x)
        codes = self._reshape_codes(flat_codes, shape)
        quantized = self.decode(codes)
        metrics: tp.Dict[str, torch.Tensor] = {}

        return _CodebookForwardResult(quantized, codes, metrics)


class VectorQuantization(nn.Module):
    """Vector quantization implementation.
    Currently supports only euclidean distance.

    Args:
        dim (int): Dimension
        codebook_size (int): Codebook size
        codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim.
        decay (float): Decay for exponential moving average over the codebooks.
        epsilon (float): Epsilon value for numerical stability.
        threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
            is replaced. This is expressed as a fraction of the usage a centroid would get under
            a uniform distribution, so that it doesn't depend on the batch size etc.
        replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
            to avoid the centroid getting replaced too quickly.
        check_unused_every (int): Check for unused centroids every `check_unused_every` iterations.
            This is to avoid too many synchronization points.
    """

    def __init__(
        self,
        dim: int,
        codebook_size: int,
        codebook_dim: tp.Optional[int] = None,
        decay: float = 0.99,
        epsilon: float = 1e-5,
        threshold_usage_ratio: float = 0.1,
        **kwargs,
    ):
        super().__init__()
        if codebook_dim is None:
            codebook_dim = dim

        requires_projection = codebook_dim != dim
        self.project_in = (
            nn.Linear(dim, codebook_dim) if requires_projection else nn.Identity()
        )
        self.project_out = (
            nn.Linear(codebook_dim, dim) if requires_projection else nn.Identity()
        )
        self.epsilon = epsilon
        self._codebook = EuclideanCodebook(
            dim=codebook_dim,
            codebook_size=codebook_size,
            decay=decay,
            epsilon=epsilon,
            threshold_usage_ratio=threshold_usage_ratio,
            **kwargs,
        )
        self.codebook_size = codebook_size

    @property
    def embedding(self):
        return self._codebook.embedding

    def _rearrange_input(self, x):
        x = rearrange(x, "b d n -> b n d")
        return x

    def _rearrange_output(self, quantized):
        quantized = rearrange(quantized, "b n d -> b d n")
        return quantized

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encodes `x` into discrete integer codes."""
        x = self._rearrange_input(x)
        x = self.project_in(x)
        codes = self._codebook.encode(x)
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Converts integer codes into quantized vectors."""
        quantized = self._codebook.decode(codes)
        quantized = self.project_out(quantized)
        quantized = self._rearrange_output(quantized)
        return quantized

    def forward(self, x: torch.Tensor, initialize: bool = True) -> _VQForwardResult:
        x = self._rearrange_input(x)
        quantized, codes, metrics = self._codebook(x, initialize=initialize)

        loss = zero_scalar(x.device)

        quantized = self.project_out(quantized)
        quantized = self._rearrange_output(quantized)

        return _VQForwardResult(quantized, codes, loss, metrics)


class ResidualVectorQuantization(nn.Module):
    """Residual vector quantization implementation.

    Follows Algorithm 1. in https://arxiv.org/pdf/2107.03312.pdf
    """

    def __init__(self, *, num_quantizers: int, codebook_offset: int, **kwargs):
        super().__init__()
        self.layers = nn.ModuleList(
            [VectorQuantization(**kwargs) for _ in range(num_quantizers)]
        )
        self.codebook_offset = codebook_offset

    def forward(
        self, x: torch.Tensor, n_q: tp.Optional[int] = None
    ) -> _VQForwardResult:
        """
        Args:
            x (torch.Tensor): input tensor to quantize, of shape `[B, C, T]`.
            n_q (int or None): if provided, number of codebook levels to use in RVQ.
        """

        quantized_out = zero_scalar(x.device)
        residual = x

        all_losses = []
        all_codes = []
        all_metrics: tp.Dict[str, torch.Tensor] = {}

        n_q = n_q or len(self.layers)
        previous_layer_is_initialized = True

        for i, layer in enumerate(self.layers[:n_q]):  # type: ignore
            quantized, codes, loss, metrics = layer(
                residual, initialize=previous_layer_is_initialized
            )

            quantized = quantized.detach()
            residual = residual - quantized
            quantized_out = quantized_out + quantized

            all_codes.append(codes)
            all_losses.append(loss)

            for key, value in metrics.items():
                if key in all_metrics:
                    all_metrics[key] += value / n_q
                else:
                    all_metrics[key] = value / n_q
                all_metrics[key + f"_{i + self.codebook_offset}"] = value

        out_losses, out_codes = map(torch.stack, (all_losses, all_codes))
        return _VQForwardResult(quantized_out, out_codes, out_losses, all_metrics)

    def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor:
        """Encodes `x` into discrete integer codes. If `n_q` is provided, only uses the first `n_q` codebook levels."""
        residual = x
        all_indices = []
        n_q = n_q or len(self.layers)
        for layer in self.layers[:n_q]:  # type: ignore
            indices = layer.encode(residual)
            quantized = layer.decode(indices)
            residual = residual - quantized
            all_indices.append(indices)
        out_indices = torch.stack(all_indices)
        return out_indices

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Converts the integer codes into quantized vectors."""
        quantized = zero_scalar(codes.device)
        for idx, layer_codes in enumerate(codes):
            layer = self.layers[idx]
            quantized = quantized + layer.decode(layer_codes)
        return quantized
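A minimal usage sketch for the `ResidualVectorQuantization` stack above; the sizes are illustrative and the module is untrained here, so only the shapes are meaningful. Inputs follow the `[B, D, T]` convention implied by `VectorQuantization._rearrange_input`.

import torch

rvq = ResidualVectorQuantization(
    num_quantizers=8, codebook_offset=0, dim=128, codebook_size=1024
)
x = torch.randn(4, 128, 75)       # [B, D, T] latent frames, D matching `dim`
codes = rvq.encode(x)             # [K, B, T] with K = 8 codebook levels
y = rvq.decode(codes)             # [B, D, T] reconstruction from all levels
coarse = rvq.decode(codes[:2])    # fewer levels -> coarser reconstruction

Each level quantizes the residual left by the previous one, which is why dropping the last levels degrades quality gracefully rather than breaking decoding.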
moshi/quantization/vq.py
ADDED
@@ -0,0 +1,340 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import math
import typing as tp

import torch

from .base import BaseQuantizer, QuantizedResult
from .core_vq import ResidualVectorQuantization


class ResidualVectorQuantizer(BaseQuantizer):
    """Residual Vector Quantizer.

    Args:
        dimension (int): Dimension of the codebooks.
        input_dimension (None or int): dimension of the input, defaults to `dimension` if not provided.
        output_dimension (None or int): dimension of the output, defaults to `dimension` if not provided.
        n_q (int): Number of vector quantizers used.
        q_dropout (bool): Random quantizer drop out at train time.
        no_quantization_rate (float): Gives the probability of applying no quantization at all
            at train time. The RVQ codebooks will still get the input value to learn the proper codebook.
        bins (int): Codebook size.
        decay (float): Decay for exponential moving average over the codebooks.
        threshold_usage_ratio (float): Defines the threshold for the cluster usage under which a centroid
            is replaced. This is expressed as a fraction of the usage a centroid would get under
            a uniform distribution, so that it doesn't depend on the batch size etc.
        replaced_usage_ratio (float): When replacing a centroid, use this as an initial centroid usage,
            to avoid the centroid getting replaced too quickly.
        codebook_offset (int): Offset to use for the codebook indices. This is useful when using multiple quantizers
            such as in SplitResidualVectorQuantizer.
        force_projection (bool): Whether to force input and output projections even when dimension is constant.
        generator_seed (int or None): seed used to initialize the RNG used for no quantization.
    """

    def __init__(
        self,
        dimension: int = 128,
        input_dimension: tp.Optional[int] = None,
        output_dimension: tp.Optional[int] = None,
        n_q: int = 8,
        q_dropout: bool = False,
        q_first_only_proba: float = 0.0,
        no_quantization_rate: float = 0.0,
        bins: int = 1024,
        decay: float = 0.99,
        threshold_usage_ratio: float = 0.1,
        replaced_usage_ratio: float = 1.0,
        codebook_offset: int = 0,
        force_projection: bool = False,
        generator_seed: tp.Optional[int] = None,
    ):
        super().__init__()
        self.max_n_q = n_q
        self.n_q = n_q
        self.q_dropout = q_dropout
        self.no_quantization_rate = no_quantization_rate
        self.q_first_only_proba = q_first_only_proba
        self.dimension = dimension
        self.input_dimension = input_dimension or dimension
        self.output_dimension = output_dimension or dimension
        self.bins = bins
        self.decay = decay
        self.input_proj: torch.nn.Module
        self.output_proj: torch.nn.Module
        self.generator = None
        if generator_seed is not None:
            self.generator = torch.Generator(
                device="cuda" if torch.cuda.is_available() else "cpu"
            )
            self.generator.manual_seed(generator_seed)
        if self.input_dimension == self.dimension and not force_projection:
            self.input_proj = torch.nn.Identity()
        else:
            self.input_proj = torch.nn.Conv1d(
                self.input_dimension, self.dimension, 1, bias=False
            )
        if self.output_dimension == self.dimension and not force_projection:
            self.output_proj = torch.nn.Identity()
        else:
            self.output_proj = torch.nn.Conv1d(
                self.dimension, self.output_dimension, 1, bias=False
            )
        self.vq = ResidualVectorQuantization(
            dim=self.dimension,
            codebook_size=self.bins,
            num_quantizers=self.n_q,
            decay=self.decay,
            threshold_usage_ratio=threshold_usage_ratio,
            replaced_usage_ratio=replaced_usage_ratio,
            codebook_offset=codebook_offset,
        )

    def forward(self, x: torch.Tensor, frame_rate: int):
        """
        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T] with `C` number of channels.
            frame_rate (int): frame rate of the input (e.g. `T = frame_rate * duration`), used to compute
                the bandwidth.

        Returns:
            QuantizedResult: Quantized result with the following attributes:
                - `x` (torch.Tensor): Quantized tensor of shape [B, C, T].
                - `codes` (torch.Tensor): Quantized codes of shape [B, K, T] with `K` number of codebooks.
                - `bw` (torch.Tensor): Bandwidth of the quantized tensor in kbits per second.
                - `penalty` (torch.Tensor): Commitment loss.
                - `metrics` (dict): RVQ metrics, in particular rate of dead code replacement, and entropy.
        """
        n_q = self.n_q
        x = self.input_proj(x)

        bw_per_q = math.log2(self.bins) * frame_rate / 1000
        quantized, codes, commit_loss, metrics = self.vq(x, n_q=n_q)
        B, _, _ = quantized.shape
        quantized = self.output_proj(quantized)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        bw = torch.tensor(n_q * bw_per_q).to(x)
        return QuantizedResult(
            quantized, codes, bw, penalty=torch.mean(commit_loss), metrics=metrics
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified frame rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizers to use
        and returns indices for each quantizer.
        """
        n_q = self.n_q
        if x.shape[-1] == 0:
            return torch.empty((x.shape[0], n_q, 0), device=x.device, dtype=torch.int64)

        x = self.input_proj(x)
        codes = self.vq.encode(x, n_q=n_q)
        codes = codes.transpose(0, 1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks, vq.decode expects [K, B, T].
        codes = codes.transpose(0, 1)
        quantized = self.vq.decode(codes)
        quantized = self.output_proj(quantized)
        return quantized

    @property
    def total_codebooks(self):
        return self.max_n_q

    @property
    def num_codebooks(self):
        return self.n_q

    def set_num_codebooks(self, n: int):
        assert n >= 0 and n <= self.max_n_q
        self.n_q = n

    @property
    def cardinality(self) -> int:
        return self.bins


class SplitResidualVectorQuantizer(BaseQuantizer):
    """Residual Vector Quantizer with separate projections for the first quantizer and the rest.

    Args:
        n_q (int): Number of residual vector quantizers used.
        n_q_semantic (int): Number of residual vector quantizers used for the semantic quantizer.
        no_quantization_mode (str): if 'true_skip', when doing no quantization, the input will not go
            through the sub quantizers. If `independent`, independent decisions are taken by
            the semantic and acoustic quantizers. If `same` (the default), the same decision is taken by both.
        **kwargs: Arguments to the constructor of `ResidualVectorQuantizer` that are shared between both.
    """

    def __init__(
        self,
        *,
        n_q: int = 8,
        no_quantization_rate: float = 0.0,
        no_quantization_mode: str = "same",
        n_q_semantic: int = 1,
        **kwargs,
    ):
        super().__init__()
        assert n_q > n_q_semantic, (
            f"Number of quantizers {n_q} must be larger "
            f"than the number of semantic quantizers {n_q_semantic}."
        )
        self.max_n_q = n_q
        self.n_q_semantic = n_q_semantic
        self.n_q_acoustic = n_q - n_q_semantic
        if no_quantization_mode == "true_skip":
            self.no_quantization_rate = no_quantization_rate
            # Setting to zero for the underlying RVQ.
            no_quantization_rate = 0.0
        else:
            self.no_quantization_rate = 0.0
        if no_quantization_mode == "same":
            kwargs["generator_seed"] = 1234
        kwargs["no_quantization_rate"] = no_quantization_rate
        q_dropout = kwargs.pop("q_dropout", False)
        self.rvq_first = ResidualVectorQuantizer(
            n_q=n_q_semantic, force_projection=True, q_dropout=False, **kwargs
        )
        self.rvq_rest = ResidualVectorQuantizer(
            n_q=n_q - n_q_semantic,
            codebook_offset=1,
            force_projection=True,
            q_dropout=q_dropout,
            **kwargs,
        )
        if no_quantization_mode == "true_skip":
            assert self.rvq_first.input_dimension == self.rvq_first.output_dimension
            assert self.rvq_rest.input_dimension == self.rvq_rest.output_dimension

    def _renorm_and_add(
        self,
        first_val: torch.Tensor,
        rest_val: torch.Tensor,
        n_q_semantic: int,
        n_q_acoustic: int,
    ):
        """Renormalizes values from `rvq_first` and `rvq_rest` and adds them.

        This allows correcting statistics that are normalized by the number of quantizers. To renormalize, we use the
        number of quantizers that are actually used, e.g. taking into account quantizer dropout.
        """
        n_q = n_q_semantic + n_q_acoustic
        renorm_first_val = first_val * n_q_semantic / n_q
        renorm_rest_val = rest_val * n_q_acoustic / n_q
        return renorm_first_val + renorm_rest_val

    def forward(self, x: torch.Tensor, frame_rate: int):
        """
        Args:
            x (torch.Tensor): Input tensor of shape [B, C, T] with `C` number of channels.
            frame_rate (int): frame rate of the input (e.g. `T = frame_rate * duration`), used to compute
                the bandwidth.

        Returns:
            QuantizedResult: Quantized result with the following attributes:
                - `x` (torch.Tensor): Quantized tensor of shape [B, C, T].
                - `codes` (torch.Tensor): Quantized codes of shape [B, K, T] with `K` number of codebooks.
                - `bw` (torch.Tensor): Bandwidth of the quantized tensor in kbits per second.
                - `penalty` (torch.Tensor): Commitment loss.
                - `metrics` (dict): RVQ metrics, in particular rate of dead code replacement, and entropy.
        """
        semantic_result = self.rvq_first(x, frame_rate)
        if self.n_q == self.n_q_semantic:
            return semantic_result
        acoustic_result = self.rvq_rest(x, frame_rate)
        full_quantized_emb = semantic_result.x + acoustic_result.x
        full_quantized_codes = torch.cat(
            [semantic_result.codes, acoustic_result.codes], dim=1
        )
        # This is the actual number of quantizers used, e.g. taking into account quantizer dropout.
        n_q_semantic = semantic_result.codes.shape[1]
        n_q_acoustic = acoustic_result.codes.shape[1]
        full_quantized_bandwidth = semantic_result.bandwidth + acoustic_result.bandwidth
        full_quantized_penalty = self._renorm_and_add(
            semantic_result.penalty, acoustic_result.penalty, n_q_semantic, n_q_acoustic
        )
        full_quantized_metrics = semantic_result.metrics
        for key, value in acoustic_result.metrics.items():
            if key in full_quantized_metrics:
                full_quantized_metrics[key] = self._renorm_and_add(
                    full_quantized_metrics[key], value, n_q_semantic, n_q_acoustic
                )
            else:
                full_quantized_metrics[key] = value
        return QuantizedResult(
            full_quantized_emb,
            full_quantized_codes,
            full_quantized_bandwidth,
            penalty=full_quantized_penalty,
            metrics=full_quantized_metrics,
        )

    def encode(self, x: torch.Tensor) -> torch.Tensor:
        """Encode a given input tensor with the specified frame rate at the given bandwidth.
        The RVQ encode method sets the appropriate number of quantizers to use
        and returns indices for each quantizer.
        """
        codes = self.rvq_first.encode(x)
        if self.n_q > self.n_q_semantic:
            acoustic_codes = self.rvq_rest.encode(x)
            codes = torch.cat([codes, acoustic_codes], dim=1)
        # codes is [B, K, T], with T frames, K nb of codebooks.
        return codes

    def decode(self, codes: torch.Tensor) -> torch.Tensor:
        """Decode the given codes to the quantized representation."""
        # codes is [B, K, T], with T frames, K nb of codebooks.
        quantized = self.rvq_first.decode(codes[:, : self.n_q_semantic])
        if codes.shape[1] > self.n_q_semantic:
            quantized += self.rvq_rest.decode(codes[:, self.n_q_semantic :])
        return quantized

    @property
    def total_codebooks(self):
        return self.rvq_first.max_n_q + self.rvq_rest.max_n_q

    @property
    def num_codebooks(self):
        return self.rvq_first.num_codebooks + self.rvq_rest.num_codebooks

    @property
    def n_q(self):
        return self.rvq_first.n_q + self.rvq_rest.n_q

    @property
    def dimension(self):
        return self.rvq_first.dimension

    @property
    def semantic_quantizer(self) -> ResidualVectorQuantizer:
        """This returns the quantizer that models the first level of the hierarchy (typically semantic)."""
        return self.rvq_first

    @property
    def acoustic_quantizer(self) -> ResidualVectorQuantizer:
        """This returns the quantizer that models the higher levels of the hierarchy (typically acoustic)."""
        return self.rvq_rest

    def set_num_codebooks(self, n: int):
        assert n >= self.n_q_semantic and n <= self.total_codebooks
        self.rvq_rest.set_num_codebooks(n - self.n_q_semantic)

    @property
    def cardinality(self) -> int:
        assert self.rvq_rest.cardinality == self.rvq_first.cardinality
        return self.rvq_first.cardinality
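A corresponding sketch for `SplitResidualVectorQuantizer`, with illustrative sizes; note that `encode` stacks the semantic codebook(s) before the acoustic ones along dim 1, which is what the slicing in `decode` relies on.

import torch

quantizer = SplitResidualVectorQuantizer(
    n_q=8, n_q_semantic=1, dimension=256, bins=2048   # illustrative sizes
)
x = torch.randn(2, 256, 100)                     # [B, C, T] latent frames
codes = quantizer.encode(x)                      # [B, 8, T]: 1 semantic + 7 acoustic codebooks
y = quantizer.decode(codes)                      # [B, C, T]
semantic_only = quantizer.decode(codes[:, :1])   # decode from the semantic level alone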
moshi/utils/__init__.py
ADDED
@@ -0,0 +1,10 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
"""Utilities."""
moshi/utils/autocast.py
ADDED
@@ -0,0 +1,45 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

import torch


class TorchAutocast:
    """TorchAutocast utility class.
    Allows you to enable and disable autocast. This is especially useful
    when dealing with different architectures and clusters with different
    levels of support.

    Args:
        enabled (bool): Whether to enable torch.autocast or not.
        args: Additional args for torch.autocast.
        kwargs: Additional kwargs for torch.autocast.
    """

    def __init__(self, enabled: bool, *args, **kwargs):
        self.autocast = torch.autocast(*args, **kwargs) if enabled else None

    def __enter__(self):
        if self.autocast is None:
            return
        try:
            self.autocast.__enter__()
        except RuntimeError:
            device = self.autocast.device
            dtype = self.autocast.fast_dtype
            raise RuntimeError(
                f"There was an error autocasting with dtype={dtype} device={device}\n"
                "If you are on the FAIR Cluster, you might need to use autocast_dtype=float16"
            )

    def __exit__(self, *args, **kwargs):
        if self.autocast is None:
            return
        self.autocast.__exit__(*args, **kwargs)
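A short usage sketch for `TorchAutocast`: the point of the wrapper is that `enabled=False` turns the whole block into a no-op, so call sites stay uniform across CPU and GPU.

import torch

use_amp = torch.cuda.is_available()
device = "cuda" if use_amp else "cpu"
model = torch.nn.Linear(16, 16).to(device)
x = torch.randn(8, 16, device=device)
with TorchAutocast(enabled=use_amp, device_type="cuda", dtype=torch.float16):
    y = model(x)   # runs in float16 on GPU, plain float32 otherwise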
moshi/utils/compile.py
ADDED
@@ -0,0 +1,284 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

"""
Provides some extra utilities around torch compile, in particular a way
to fully deactivate it easily with a context manager.
Provides a simple activation checkpointing that is compatible with FSDP and torch compile.
Finally, provides some utilities for CUDA graphing functions.
"""
from contextlib import contextmanager
from functools import wraps
import inspect
import os
import typing as tp

import torch
from torch import cuda


_compile_disabled: bool = False


@contextmanager
def no_compile():
    """Disable torch.compile locally. Now PyTorch 2.4 provides a function to do that."""
    global _compile_disabled

    prev_disabled = _compile_disabled
    _compile_disabled = True
    try:
        yield
    finally:
        _compile_disabled = prev_disabled


def torch_compile_lazy(fun):
    """torch.compile creates a huge pool of processes, even when not using the function at all,
    e.g. with Dora. This can pollute stderr when doing CTRL+C. So we do it in a lazy way.
    """
    if os.environ.get("NO_TORCH_COMPILE"):
        return fun
    fun_compiled = None

    @wraps(fun)
    def _wrapped(*args, **kwargs):
        nonlocal fun_compiled
        if _compile_disabled:
            return fun(*args, **kwargs)
        if fun_compiled is None:
            fun_compiled = torch.compile(fun)
        return fun_compiled(*args, **kwargs)

    return _wrapped


class Checkpoint(torch.autograd.Function):
    @staticmethod
    def forward(ctx, function, *args) -> tp.Any:
        to_save = []
        ctx.others = []
        ctx.function = function
        # Sources will indicate whether the arg in position N is
        # a tensor stored in ctx.save_for_backward, or inside ctx.others.
        ctx.sources = []
        new_args = []
        for arg in args:
            if isinstance(arg, torch.Tensor):
                to_save.append(arg)
                ctx.sources.append("tensor")
                new_args.append(arg.detach())
            else:
                ctx.sources.append("other")
                ctx.others.append(arg)
                new_args.append(arg)
        ctx.save_for_backward(*to_save)
        # During the forward, we just make a pass with no gradient computed.
        with torch.no_grad():
            res = function(*new_args)
        return res

    @staticmethod
    def backward(ctx, *grads) -> tp.Tuple[tp.Optional[torch.Tensor], ...]:
        pseudo_tensors = []
        with torch.set_grad_enabled(True):
            # We create leaf tensors to collect the output gradients.
            # We call them pseudo_tensors because they are pretending to be the input
            # to `function` but are not directly.
            for tensor in ctx.saved_tensors:
                pseudo_tensor = tensor.detach()
                pseudo_tensor.requires_grad_(True)
                pseudo_tensors.append(pseudo_tensor)
            pseudo_tensors_copy = list(pseudo_tensors)
            args = []
            for source in ctx.sources:
                if source == "other":
                    args.append(ctx.others.pop(0))
                else:
                    assert source == "tensor"
                    args.append(pseudo_tensors_copy.pop(0))
            res = ctx.function(*args)
            # The second forward with grad computation allows us to connect the input leaf tensors
            # inside pseudo_tensors, to the outputs of the function called.
            if not isinstance(res, tuple):
                res = (res,)
            # Now we just ask Torch to compute the derivative of `res` given the gradient coming from above
            # `grads`. The computed gradient will end up into the `pseudo_tensors` grad attributes.
            torch.autograd.backward(res, grads)
            out: tp.List[tp.Optional[torch.Tensor]] = [None]
            for source in ctx.sources:
                # We still need to output `None` values for non tensor parameters.
                if source == "other":
                    out.append(None)
                else:
                    assert source == "tensor"
                    out.append(pseudo_tensors.pop(0).grad)
            return tuple(out)


def simple_checkpoint(module: torch.nn.Module, *args, **kwargs):
    """Custom implementation of checkpointing in PyTorch as the builtin implementation is broken
    when using torch compile. Only supports wrapping a `nn.Module` with a forward with no `*args` or `**kwargs`.

    https://github.com/pytorch/pytorch/issues/97436.
    Should be resolved in nightlies, but it is quite fun and simple to code it ourselves.
    """
    if hasattr(module, "_fsdp_wrapped_module"):
        module_for_sig = module._fsdp_wrapped_module
    else:
        module_for_sig = module
    sig = inspect.signature(module_for_sig.forward)
    # We first flatten all arguments to use only *args, to make things easier and because
    # torch.autograd.Function has weird support for kwargs.
    bounded = sig.bind(*args, **kwargs)
    new_args = []
    for name, param in sig.parameters.items():
        if param.kind in {
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        }:
            raise RuntimeError("simple_checkpoint doesn't support var args.")
        if name not in bounded.arguments:
            break
        new_args.append(bounded.arguments[name])
    return Checkpoint.apply(module, *new_args)


_in_cuda_graph = False
_disable_cuda_graph = False


def in_cuda_graph() -> bool:
    """Indicate whether we are in a function that is CUDA Graphed (or will be soon)."""
    return _in_cuda_graph


@contextmanager
def _set_in_cuda_graph():
    global _in_cuda_graph
    assert not _in_cuda_graph
    _in_cuda_graph = True
    try:
        yield
    finally:
        _in_cuda_graph = False


def _is_cuda_graph_enabled() -> bool:
    if _disable_cuda_graph:
        return False
    no_cuda_graph = os.environ.get("NO_CUDA_GRAPH", "")
    if no_cuda_graph.lower() not in {"0", "no", "n", ""}:
        return False
    return True


@contextmanager
def no_cuda_graph():
    """Deactivate CUDA Graphing for all the calls in this context manager."""
    global _disable_cuda_graph
    old_value = _disable_cuda_graph
    _disable_cuda_graph = True
    try:
        yield
    finally:
        _disable_cuda_graph = old_value


class CUDAGraphed:
    """Allow simple CUDA Graphing of a function.

    Args:
        func: callable, taking any number of arguments. Its tensors arguments should
            be top level args, not nested in structures (tuples, dicts, etc). Keyword
            arguments are NOT supported for simplicity.
        warmup_steps: how many calls to make normally before CUDA Graphing. In particular, this
            allows torch.compiled functions to get properly compiled.
        disable: if True, just call the func directly, useful to quickly deactivate on CPU.
    """

    def __init__(self, func: tp.Callable, warmup_steps: int = 1, disable: bool = False):
        self.func = func
        self.warmup_steps = warmup_steps
        self.disable = disable
        self._graph: cuda.CUDAGraph | None = None
        self._output: tuple | None = None
        self._args: tuple | None = None

    def reset(self, warmup_steps: int = 0) -> None:
        """Reset the state, meaning the next call we get CUDA Graphed again. Useful if some
        shapes have changed, or external state (e.g. KVCache) has changed."""
        self.warmup_steps = warmup_steps
        self._graph = None
        self._output = None
        self._args = None

    def __call__(self, *args, **kwargs) -> tp.Any:
        if kwargs:
            raise RuntimeError("Named arguments not supported for now.")
        if self.disable or not _is_cuda_graph_enabled() or in_cuda_graph():
            return self.func(*args, **kwargs)

        def _clone_tensors(args: tuple) -> tuple:
            out: list = []
            for arg in args:
                if isinstance(arg, torch.Tensor):
                    arg = arg.clone()
                out.append(arg)
            return tuple(out)

        def _match_values_copy_tensors(args: tuple, target_args: tuple) -> None:
            if len(args) != len(target_args):
                raise ValueError(
                    f"Expected {len(target_args)}, but got {args} for CUDA Graphed function."
                )
            for idx, (source, target) in enumerate(zip(args, target_args)):
                if isinstance(target, torch.Tensor):
                    if not isinstance(source, torch.Tensor):
                        raise ValueError(
                            f"Argument #{idx} was a tensor, and is no longer (now {source})."
                        )
                    if source.shape != target.shape:
                        raise ValueError(
                            f"Argument #{idx} had shape {target.shape}, but got shape {source.shape}"
                        )
                    target.copy_(source)
                else:
                    if isinstance(source, torch.Tensor):
                        raise ValueError(
                            f"Argument #{idx} was not a tensor {target}, but is now one."
                        )
                    if source is not target and source != target:
                        raise ValueError(
                            f"Argument #{idx} changed value from {target} to {source}."
                        )

        with _set_in_cuda_graph():
            # Prevent anyone under us from trying to CUDA Graph things.
            if self._graph is None:
                if self.warmup_steps <= 0:
                    self._graph = cuda.CUDAGraph()
                    # Making a copy just to ensure those are not used elsewhere.
                    self._args = _clone_tensors(args)
                    with cuda.graph(self._graph):
                        self._output = self.func(*self._args)
                    # At this point nothing really happened, so we have to make it run for real.
                    self._graph.replay()
                    return self._output
                else:
                    self.warmup_steps -= 1
                    return self.func(*args)
            else:
                assert self._args is not None
                assert self._output is not None
                _match_values_copy_tensors(args, self._args)
                self._graph.replay()
                return self._output


def cuda_graph(func: tp.Callable, warmup_steps: int = 1):
    """Just calls `CUDAGraphed` on the given function."""
    if not _is_cuda_graph_enabled():
        return func
    return CUDAGraphed(func, warmup_steps)
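A sketch of how these pieces are meant to compose, assuming a CUDA device: compile lazily, warm up so the compiled kernels exist, then let `CUDAGraphed` capture and replay the step. The `decode_step` function below is a stand-in, not one from this repo.

import torch

@torch_compile_lazy
def decode_step(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for a real per-frame decoding step.
    return torch.relu(x) + 1.0

if torch.cuda.is_available():
    step = CUDAGraphed(decode_step, warmup_steps=2)
    x = torch.randn(4, 8, device="cuda")
    for _ in range(3):     # two warmup calls, then capture + replay
        out = step(x)
    step.reset()           # e.g. after a shape or KV cache change
else:
    with no_compile():     # opt out of torch.compile entirely on CPU
        out = decode_step(torch.randn(4, 8))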
moshi/utils/sampling.py
ADDED
@@ -0,0 +1,126 @@
# Copyright (c) Kyutai, all rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.


import torch


def multinomial(
    input: torch.Tensor, num_samples: int, replacement=False, *, generator=None
):
    """torch.multinomial with arbitrary number of dimensions, and number of candidates on the last dimension.

    Args:
        input (torch.Tensor): The input tensor containing probabilities.
        num_samples (int): Number of samples to draw.
        replacement (bool): Whether to draw with replacement or not.
    Keyword args:
        generator (torch.Generator): A pseudorandom number generator for sampling.
    Returns:
        torch.Tensor: Last dimension contains num_samples indices
            sampled from the multinomial probability distribution
            located in the last dimension of tensor input.
    """
    input_ = input.reshape(-1, input.shape[-1])
    # We should probably be able to remove this once the following PR has landed:
    # https://github.com/pytorch/pytorch/pull/134818/files
    # In the meantime, we specialize the case no-replacement, nsamples=1 so as not
    # to have a synchronization point.
    if replacement or num_samples != 1:
        output_ = torch.multinomial(
            input_,
            num_samples=num_samples,
            replacement=replacement,
            generator=generator,
        )
    else:
        q = torch.empty_like(input_).exponential_(1, generator=generator)
        q = input_ / q
        output_ = q.argmax(dim=-1, keepdim=True)
    output = output_.reshape(*list(input.shape[:-1]), -1)
    return output


def sample_top_k(probs: torch.Tensor, k: int) -> torch.Tensor:
    """Sample next token from top K values along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        k (int): The k in “top-k”.
    Returns:
        torch.Tensor: Sampled tokens.
    """
    probs, indices = torch.topk(probs, k, dim=-1)
    next_token = multinomial(probs, num_samples=1)
    next_token = indices.gather(-1, next_token)
    return next_token


def sample_top_p(probs: torch.Tensor, p: float) -> torch.Tensor:
    """Sample next token from top P probabilities along the last dimension of the input probs tensor.

    Args:
        probs (torch.Tensor): Input probabilities with token candidates on the last dimension.
        p (float): The p in “top-p”.
    Returns:
        torch.Tensor: Sampled tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort *= (~mask).float()
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token


def sample_token(
    logits: torch.Tensor,
    use_sampling: bool = False,
    temp: float = 1.0,
    top_k: int = 0,
    top_p: float = 0.0,
) -> torch.Tensor:
    """Given logits of shape [*, Card], returns a LongTensor of shape [*]."""
    # Apply softmax for sampling if temp > 0. Else, do greedy sampling to avoid zero division error.
    if use_sampling and temp > 0.0:
        probs = torch.softmax(logits / temp, dim=-1)
        if top_p > 0.0:
            next_token = sample_top_p(probs, p=top_p)
        elif top_k > 0:
            next_token = sample_top_k(probs, k=top_k)
        else:
            next_token = multinomial(probs, num_samples=1)
    else:
        next_token = torch.argmax(logits, dim=-1, keepdim=True)
    assert next_token.shape[-1] == 1
    return next_token[..., 0]


if __name__ == "__main__":
    torch.manual_seed(1234)
    device = "cpu"
    if torch.cuda.is_available():
        torch.backends.cuda.matmul.allow_tf32 = False
        torch.backends.cudnn.allow_tf32 = False
        device = "cuda:0"

    ps = torch.tensor([5.0, 2.0, 12.0, 6.0, 8.0, 1.0, 0.0, 4.0], device=device)
    cnts = torch.zeros(ps.shape, dtype=torch.long, device=device)
    total_samples = 1000
    for _ in range(total_samples):
        vs = multinomial(ps, num_samples=1, replacement=False)
        cnts[vs] += 1
    diff = cnts / cnts.sum() - ps / ps.sum()
    max_diff = diff.abs().max().cpu().item()
    print(ps / ps.sum())
    print(cnts / cnts.sum())
    assert max_diff < 1.5e-2
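A usage sketch for the samplers above; `sample_token` is the entry point and dispatches to top-p, top-k, or plain multinomial sampling depending on its arguments. The vocabulary size below is illustrative.

import torch

logits = torch.randn(2, 4, 32000)                      # [*, Card]
greedy = sample_token(logits)                          # argmax, shape [2, 4]
topk = sample_token(logits, use_sampling=True, temp=0.8, top_k=250)
topp = sample_token(logits, use_sampling=True, temp=1.0, top_p=0.95)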
pyproject.toml
ADDED
@@ -0,0 +1,39 @@
[build-system]
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
name = "moshi-hf"
version = "0.1.0"
description = "Moshi HuggingFace inference server"
requires-python = ">=3.8"
dependencies = [
    "torch==2.4.1",
    "torchaudio==2.4.1",
    "torchvision==0.19.1",
    "torchdata==0.10.0",
    "transformers==4.46.3",
    "huggingface-hub==0.27.1",
    "safetensors==0.5.1",
    "accelerate>=0.20.0",
    "datasets==3.1.0",
    "requests==2.32.3",
    "urllib3==2.2.3",
    "pyyaml==6.0.2",
    "einops>=0.6.1",
    "ulid-py==1.1.0",
    "tqdm>=4.65.0",
    "sentencepiece==0.2.0",
    "jiwer==3.0.5",
    "numpy==1.24.4",
    "pandas==2.0.3",
    "scikit-learn==1.5.2",
    "pydub==0.25.1",
    "librosa==0.10.2.post1",
    "pyannote.audio==3.1.1",
    "pesq==0.0.4",
    "torchmetrics==1.6.0",
    "uvicorn==0.25.0",
    "fastapi==0.104.1",
    "pydantic==2.5.2"
]
requirements.txt
ADDED
@@ -0,0 +1,43 @@
# PyTorch ecosystem with fixed versions
torch==2.4.1
torchaudio==2.4.1
torchvision==0.19.1
torchdata==0.10.0

# HuggingFace ecosystem
transformers==4.46.3
huggingface-hub==0.27.1
safetensors==0.5.1
accelerate>=0.20.0
datasets==3.1.0

# HTTP and utils
requests==2.32.3
urllib3==2.2.3
pyyaml==6.0.2
einops>=0.6.1
ulid-py==1.1.0
tqdm>=4.65.0

# NLP and text processing
sentencepiece==0.2.0
jiwer==3.0.5

# Data processing
numpy==1.24.4
pandas==2.0.3
scikit-learn==1.5.2

# Audio processing
pydub==0.25.1
librosa==0.10.2.post1
pyannote.audio==3.1.1
pesq==0.0.4

# Metrics and evaluation
torchmetrics==1.6.0

# Server
uvicorn==0.25.0
fastapi==0.104.1
pydantic==2.5.2
server.py
ADDED
@@ -0,0 +1,121 @@
1 |
+
from fastapi import FastAPI, HTTPException
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
from pydantic import BaseModel
|
5 |
+
import base64
|
6 |
+
import io
|
7 |
+
import os
|
8 |
+
import logging
|
9 |
+
from pathlib import Path
|
10 |
+
from inference import InferenceRecipe
|
11 |
+
from fastapi.middleware.cors import CORSMiddleware
|
12 |
+
|
13 |
+
logging.basicConfig(level=logging.INFO)
|
14 |
+
logger = logging.getLogger(__name__)
|
15 |
+
|
16 |
+
app = FastAPI()
|
17 |
+
|
18 |
+
# Add CORS middleware
|
19 |
+
app.add_middleware(
|
20 |
+
CORSMiddleware,
|
21 |
+
allow_origins=["*"],
|
22 |
+
allow_credentials=True,
|
23 |
+
allow_methods=["*"],
|
24 |
+
allow_headers=["*"],
|
25 |
+
)
|
26 |
+
|
27 |
+
class AudioRequest(BaseModel):
|
28 |
+
audio_data: str
|
29 |
+
sample_rate: int
|
30 |
+
|
31 |
+
class AudioResponse(BaseModel):
|
32 |
+
audio_data: str
|
33 |
+
text: str = ""
|
34 |
+
|
35 |
+
# Model initialization status
|
36 |
+
INITIALIZATION_STATUS = {
|
37 |
+
"model_loaded": False,
|
38 |
+
"error": None
|
39 |
+
}
|
40 |
+
|
41 |
+
# Global model instance
|
42 |
+
model = None
|
43 |
+
|
44 |
+
def initialize_model():
|
45 |
+
"""Initialize the model from mounted directory"""
|
46 |
+
global model, INITIALIZATION_STATUS
|
47 |
+
try:
|
48 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
49 |
+
logger.info(f"Initializing model on device: {device}")
|
50 |
+
|
51 |
+
model_path = os.getenv("MODEL_PATH", "/app/models")
|
52 |
+
if not os.path.exists(model_path):
|
53 |
+
raise RuntimeError(f"Model path {model_path} does not exist")
|
54 |
+
|
55 |
+
model = InferenceRecipe(model_path, device=device)
|
56 |
+
INITIALIZATION_STATUS["model_loaded"] = True
|
57 |
+
logger.info("Model initialized successfully")
|
58 |
+
return True
|
59 |
+
except Exception as e:
|
60 |
+
INITIALIZATION_STATUS["error"] = str(e)
|
61 |
+
logger.error(f"Failed to initialize model: {e}")
|
62 |
+
return False
|
63 |
+
|
64 |
+
@app.on_event("startup")
|
65 |
+
async def startup_event():
|
66 |
+
"""Initialize model on startup"""
|
67 |
+
initialize_model()
|
68 |
+
|
69 |
+
@app.get("/api/v1/health")
|
70 |
+
def health_check():
|
71 |
+
"""Health check endpoint"""
|
72 |
+
status = {
|
73 |
+
"status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
|
74 |
+
"gpu_available": torch.cuda.is_available(),
|
75 |
+
"initialization_status": INITIALIZATION_STATUS
|
76 |
+
}
|
77 |
+
|
78 |
+
if model is not None:
|
79 |
+
status.update({
|
80 |
+
"device": str(model.device),
|
81 |
+
"model_path": str(model.model_path),
|
82 |
+
"mimi_loaded": model.mimi is not None,
|
83 |
+
"tokenizer_loaded": model.text_tokenizer is not None,
|
84 |
+
"lm_loaded": model.lm_gen is not None
|
85 |
+
})
|
86 |
+
|
87 |
+
return status
|
88 |
+
|
89 |
+
@app.post("/api/v1/inference")
|
90 |
+
async def inference(request: AudioRequest) -> AudioResponse:
|
91 |
+
"""Run inference on audio input"""
|
92 |
+
if not INITIALIZATION_STATUS["model_loaded"]:
|
93 |
+
raise HTTPException(
|
94 |
+
status_code=503,
|
95 |
+
detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
|
96 |
+
)
|
97 |
+
|
98 |
+
try:
|
99 |
+
# Decode audio from base64
|
100 |
+
audio_bytes = base64.b64decode(request.audio_data)
|
101 |
+
audio_array = np.load(io.BytesIO(audio_bytes))
|
102 |
+
|
103 |
+
# Run inference
|
104 |
+
result = model.inference(audio_array, request.sample_rate)
|
105 |
+
|
106 |
+
# Encode output audio
|
107 |
+
buffer = io.BytesIO()
|
108 |
+
np.save(buffer, result['audio'])
|
109 |
+
audio_b64 = base64.b64encode(buffer.getvalue()).decode()
|
110 |
+
|
111 |
+
return AudioResponse(
|
112 |
+
audio_data=audio_b64,
|
113 |
+
text=result.get("text", "")
|
114 |
+
)
|
115 |
+
except Exception as e:
|
116 |
+
logger.error(f"Inference failed: {str(e)}")
|
117 |
+
raise HTTPException(status_code=500, detail=str(e))
|
118 |
+
|
119 |
+
if __name__ == "__main__":
|
120 |
+
import uvicorn
|
121 |
+
uvicorn.run(app, host="0.0.0.0", port=8000)
|
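For reference, a minimal client sketch for the two endpoints above. It assumes a local deployment on port 8000; the 24 kHz sample rate and the one-second sine input are illustrative placeholders, not values mandated by the server:

# Sketch: call /api/v1/health, then round-trip audio through /api/v1/inference.
import base64
import io

import numpy as np
import requests

SERVER_URL = "http://localhost:8000"  # assumption: local deployment

# Check that the model finished loading before sending audio
health = requests.get(f"{SERVER_URL}/api/v1/health").json()
print("server status:", health["status"])

# Build a dummy mono waveform and serialize it the way the server
# expects: np.save bytes, base64-encoded into the JSON payload
sample_rate = 24000
t = np.linspace(0, 1.0, sample_rate, dtype=np.float32)
waveform = 0.1 * np.sin(2 * np.pi * 440.0 * t)

buffer = io.BytesIO()
np.save(buffer, waveform)
payload = {
    "audio_data": base64.b64encode(buffer.getvalue()).decode(),
    "sample_rate": sample_rate,
}

response = requests.post(f"{SERVER_URL}/api/v1/inference", json=payload)
response.raise_for_status()
body = response.json()

# Decode the returned audio back into a numpy array
audio_out = np.load(io.BytesIO(base64.b64decode(body["audio_data"])))
print("generated text:", body["text"])
print("output audio shape:", audio_out.shape)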
setup.py
ADDED
@@ -0,0 +1,78 @@
# The MIT License (MIT)
# Copyright © 2023 Yuma Rao
# Copyright © 2024 Omega Labs, Inc.

# Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated
# documentation files (the “Software”), to deal in the Software without restriction, including without limitation
# the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software,
# and to permit persons to whom the Software is furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all copies or substantial portions of
# the Software.

# THE SOFTWARE IS PROVIDED “AS IS”, WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO
# THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
# THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER
# DEALINGS IN THE SOFTWARE.

import re
import os
import codecs
from os import path
from io import open
from setuptools import setup, find_packages


def read_requirements(path):
    with open(path, "r") as f:
        requirements = f.read().splitlines()
    return requirements


requirements = read_requirements("requirements.txt")
here = path.abspath(path.dirname(__file__))

with open(path.join(here, "README.md"), encoding="utf-8") as f:
    long_description = f.read()

# Load the version string from template/__init__.py
with codecs.open(
    os.path.join(here, "template/__init__.py"), encoding="utf-8"
) as init_file:
    version_match = re.search(
        r"^__version__ = ['\"]([^'\"]*)['\"]", init_file.read(), re.M
    )
    version_string = version_match.group(1)

setup(
    name="omegalabs-anytoany-bittensor",
    version=version_string,
    description="omegalabs-anytoany-bittensor",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/omegalabsinc/omegalabs-anytoany-bittensor",
    author="OMEGA Labs, Inc.",
    packages=find_packages(),
    include_package_data=True,
    author_email="[email protected]",
    license="MIT",
    python_requires=">=3.8",
    install_requires=requirements,
    classifiers=[
        "Development Status :: 3 - Alpha",
        "Intended Audience :: Developers",
        "Topic :: Software Development :: Build Tools",
        "License :: OSI Approved :: MIT License",
        "Programming Language :: Python :: 3 :: Only",
        "Programming Language :: Python :: 3.8",
        "Programming Language :: Python :: 3.9",
        "Programming Language :: Python :: 3.10",
        "Topic :: Scientific/Engineering",
        "Topic :: Scientific/Engineering :: Mathematics",
        "Topic :: Scientific/Engineering :: Artificial Intelligence",
        "Topic :: Software Development",
        "Topic :: Software Development :: Libraries",
        "Topic :: Software Development :: Libraries :: Python Modules",
    ],
)
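Note that the version lookup above requires template/__init__.py to define __version__ at the start of a line. A hypothetical example of that file's contents (the "0.1.0" value mirrors pyproject.toml and is otherwise an assumption):

# template/__init__.py (hypothetical contents): the regex in setup.py,
# r"^__version__ = ['\"]([^'\"]*)['\"]", matches exactly this form.
__version__ = "0.1.0"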