MartsoBodziu1994 committed
Commit 2337a8e · verified · 1 Parent(s): 06ceabb

Upload 7 files

Files changed (7)
  1. __init__.py +2 -0
  2. __main__.py +3 -0
  3. api.py +125 -0
  4. cli.py +71 -0
  5. generation.py +820 -0
  6. model.py +218 -0
  7. model_fine.py +149 -0
__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .api import generate_audio, text_to_semantic, semantic_to_waveform, save_as_prompt
2
+ from .generation import SAMPLE_RATE, preload_models
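
The two lines above define the package's public API. A minimal usage sketch, assuming the package is importable as `bark` and that scipy is installed for writing the WAV file:

    from scipy.io.wavfile import write as write_wav
    from bark import SAMPLE_RATE, generate_audio, preload_models

    # optional: download and cache all sub-models up front;
    # generate_audio() lazy-loads them on first use otherwise
    preload_models()

    # text -> semantic -> coarse -> fine -> waveform, returned as a numpy float array
    audio_array = generate_audio("Hello, this is a test.", text_temp=0.7, waveform_temp=0.7)

    # mono audio at 24 kHz
    write_wav("bark_generation.wav", SAMPLE_RATE, audio_array)
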
__main__.py ADDED
@@ -0,0 +1,3 @@
1
+ from .cli import cli
2
+
3
+ cli()
api.py ADDED
@@ -0,0 +1,125 @@
1
+ from typing import Dict, Optional, Union
2
+
3
+ import numpy as np
4
+
5
+ from .generation import codec_decode, generate_coarse, generate_fine, generate_text_semantic
6
+
7
+
8
+ def text_to_semantic(
9
+ text: str,
10
+ history_prompt: Optional[Union[Dict, str]] = None,
11
+ temp: float = 0.7,
12
+ silent: bool = False,
13
+ ):
14
+ """Generate semantic array from text.
15
+
16
+ Args:
17
+ text: text to be turned into audio
18
+ history_prompt: history choice for audio cloning
19
+ temp: generation temperature (1.0 more diverse, 0.0 more conservative)
20
+ silent: disable progress bar
21
+
22
+ Returns:
23
+ numpy semantic array to be fed into `semantic_to_waveform`
24
+ """
25
+ x_semantic = generate_text_semantic(
26
+ text,
27
+ history_prompt=history_prompt,
28
+ temp=temp,
29
+ silent=silent,
30
+ use_kv_caching=True
31
+ )
32
+ return x_semantic
33
+
34
+
35
+ def semantic_to_waveform(
36
+ semantic_tokens: np.ndarray,
37
+ history_prompt: Optional[Union[Dict, str]] = None,
38
+ temp: float = 0.7,
39
+ silent: bool = False,
40
+ output_full: bool = False,
41
+ ):
42
+ """Generate audio array from semantic input.
43
+
44
+ Args:
45
+ semantic_tokens: semantic token output from `text_to_semantic`
46
+ history_prompt: history choice for audio cloning
47
+ temp: generation temperature (1.0 more diverse, 0.0 more conservative)
48
+ silent: disable progress bar
49
+ output_full: return full generation to be used as a history prompt
50
+
51
+ Returns:
52
+ numpy audio array at sample frequency 24 kHz
53
+ """
54
+ coarse_tokens = generate_coarse(
55
+ semantic_tokens,
56
+ history_prompt=history_prompt,
57
+ temp=temp,
58
+ silent=silent,
59
+ use_kv_caching=True
60
+ )
61
+ fine_tokens = generate_fine(
62
+ coarse_tokens,
63
+ history_prompt=history_prompt,
64
+ temp=0.5,
65
+ )
66
+ audio_arr = codec_decode(fine_tokens)
67
+ if output_full:
68
+ full_generation = {
69
+ "semantic_prompt": semantic_tokens,
70
+ "coarse_prompt": coarse_tokens,
71
+ "fine_prompt": fine_tokens,
72
+ }
73
+ return full_generation, audio_arr
74
+ return audio_arr
75
+
76
+
77
+ def save_as_prompt(filepath, full_generation):
78
+ assert(filepath.endswith(".npz"))
79
+ assert(isinstance(full_generation, dict))
80
+ assert("semantic_prompt" in full_generation)
81
+ assert("coarse_prompt" in full_generation)
82
+ assert("fine_prompt" in full_generation)
83
+ np.savez(filepath, **full_generation)
84
+
85
+
86
+ def generate_audio(
87
+ text: str,
88
+ history_prompt: Optional[Union[Dict, str]] = None,
89
+ text_temp: float = 0.7,
90
+ waveform_temp: float = 0.7,
91
+ silent: bool = False,
92
+ output_full: bool = False,
93
+ ):
94
+ """Generate audio array from input text.
95
+
96
+ Args:
97
+ text: text to be turned into audio
98
+ history_prompt: history choice for audio cloning
99
+ text_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
100
+ waveform_temp: generation temperature (1.0 more diverse, 0.0 more conservative)
101
+ silent: disable progress bar
102
+ output_full: return full generation to be used as a history prompt
103
+
104
+ Returns:
105
+ numpy audio array at sample frequency 24 kHz
106
+ """
107
+ semantic_tokens = text_to_semantic(
108
+ text,
109
+ history_prompt=history_prompt,
110
+ temp=text_temp,
111
+ silent=silent,
112
+ )
113
+ out = semantic_to_waveform(
114
+ semantic_tokens,
115
+ history_prompt=history_prompt,
116
+ temp=waveform_temp,
117
+ silent=silent,
118
+ output_full=output_full,
119
+ )
120
+ if output_full:
121
+ full_generation, audio_arr = out
122
+ return full_generation, audio_arr
123
+ else:
124
+ audio_arr = out
125
+ return audio_arr
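
The api.py helpers above compose into a prompt round-trip for voice cloning. A hedged sketch (file names are illustrative; the import path `bark.api` is an assumption about how the package is installed):

    from bark.api import generate_audio, save_as_prompt, semantic_to_waveform, text_to_semantic

    # run the full pipeline once and keep the intermediate token arrays
    full_generation, audio = generate_audio("A short calibration sentence.", output_full=True)

    # persist the semantic/coarse/fine prompts so later calls can reuse this voice
    save_as_prompt("my_voice.npz", full_generation)  # hypothetical output path

    # reuse the saved prompt, either end to end or stage by stage
    audio_again = generate_audio("Another sentence.", history_prompt="my_voice.npz")
    semantic = text_to_semantic("Or drive the stages yourself.", history_prompt="my_voice.npz")
    audio_manual = semantic_to_waveform(semantic, history_prompt="my_voice.npz")
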
cli.py ADDED
@@ -0,0 +1,71 @@
1
+ import argparse
2
+ from typing import Dict, Optional, Union
3
+ import os
4
+
5
+ from scipy.io.wavfile import write as write_wav
6
+ from .api import generate_audio
7
+ from .generation import SAMPLE_RATE
8
+
9
+
10
+ def cli():
11
+ """Commandline interface."""
12
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
13
+ parser.add_argument("--text", type=str, help="text to be turned into audio")
14
+ parser.add_argument(
15
+ "--output_filename",
16
+ type=str,
17
+ default="bark_generation.wav",
18
+ help="output audio file name",
19
+ )
20
+ parser.add_argument("--output_dir", type=str, default=".", help="directory to save the outputs")
21
+ parser.add_argument(
22
+ "--history_prompt",
23
+ type=str,
24
+ default=None,
25
+ help="history choice for audio cloning, be path to the .npz file.",
26
+ )
27
+ parser.add_argument(
28
+ "--text_temp",
29
+ default=0.7,
30
+ type=float,
31
+ help="generation temperature (1.0 more diverse, 0.0 more conservative)",
32
+ )
33
+ parser.add_argument(
34
+ "--waveform_temp",
35
+ default=0.7,
36
+ type=float,
37
+ help="generation temperature (1.0 more diverse, 0.0 more conservative)",
38
+ )
39
+ parser.add_argument("--silent", default=False, type=bool, help="disable progress bar")
40
+ parser.add_argument(
41
+ "--output_full",
42
+ default=False,
43
+ action="store_true",
44
+ help="return full generation to be used as a history prompt",
45
+ )
46
+
47
+ args = vars(parser.parse_args())
48
+ input_text: str = args.get("text")
49
+ output_filename: str = args.get("output_filename")
50
+ output_dir: str = args.get("output_dir")
51
+ history_prompt: str = args.get("history_prompt")
52
+ text_temp: float = args.get("text_temp")
53
+ waveform_temp: float = args.get("waveform_temp")
54
+ silent: bool = args.get("silent")
55
+ output_full: bool = args.get("output_full")
56
+
57
+ try:
58
+ os.makedirs(output_dir, exist_ok=True)
59
+ generated_audio = generate_audio(
60
+ input_text,
61
+ history_prompt=history_prompt,
62
+ text_temp=text_temp,
63
+ waveform_temp=waveform_temp,
64
+ silent=silent,
65
+ output_full=output_full,
66
+ )
67
+ output_file_path = os.path.join(output_dir, output_filename)
68
+ write_wav(output_file_path, SAMPLE_RATE, generated_audio)
69
+ print(f"Done! Output audio file is saved at: '{output_file_path}'")
70
+ except Exception as e:
71
+ print(f"Oops, an error occurred: {e}")
generation.py ADDED
@@ -0,0 +1,820 @@
1
+ import contextlib
2
+ import gc
3
+ import os
4
+ import re
5
+
6
+ from encodec import EncodecModel
7
+ import funcy
8
+ import logging
9
+ import numpy as np
10
+ from scipy.special import softmax
11
+ import torch
12
+ import torch.nn.functional as F
13
+ import tqdm
14
+ from transformers import BertTokenizer
15
+ from huggingface_hub import hf_hub_download
16
+
17
+ from .model import GPTConfig, GPT
18
+ from .model_fine import FineGPT, FineGPTConfig
19
+
20
+ if (
21
+ torch.cuda.is_available() and
22
+ hasattr(torch.cuda, "amp") and
23
+ hasattr(torch.cuda.amp, "autocast") and
24
+ hasattr(torch.cuda, "is_bf16_supported") and
25
+ torch.cuda.is_bf16_supported()
26
+ ):
27
+ autocast = funcy.partial(torch.cuda.amp.autocast, dtype=torch.bfloat16)
28
+ else:
29
+ @contextlib.contextmanager
30
+ def autocast():
31
+ yield
32
+
33
+
34
+ # hold models in global scope to lazy load
35
+ global models
36
+ models = {}
37
+
38
+ global models_devices
39
+ models_devices = {}
40
+
41
+
42
+ CONTEXT_WINDOW_SIZE = 1024
43
+
44
+ SEMANTIC_RATE_HZ = 49.9
45
+ SEMANTIC_VOCAB_SIZE = 10_000
46
+
47
+ CODEBOOK_SIZE = 1024
48
+ N_COARSE_CODEBOOKS = 2
49
+ N_FINE_CODEBOOKS = 8
50
+ COARSE_RATE_HZ = 75
51
+
52
+ SAMPLE_RATE = 24_000
53
+
54
+
55
+ SUPPORTED_LANGS = [
56
+ ("English", "en"),
57
+ ("German", "de"),
58
+ ("Spanish", "es"),
59
+ ("French", "fr"),
60
+ ("Hindi", "hi"),
61
+ ("Italian", "it"),
62
+ ("Japanese", "ja"),
63
+ ("Korean", "ko"),
64
+ ("Polish", "pl"),
65
+ ("Portuguese", "pt"),
66
+ ("Russian", "ru"),
67
+ ("Turkish", "tr"),
68
+ ("Chinese", "zh"),
69
+ ]
70
+
71
+ ALLOWED_PROMPTS = {"announcer"}
72
+ for _, lang in SUPPORTED_LANGS:
73
+ for prefix in ("", f"v2{os.path.sep}"):
74
+ for n in range(10):
75
+ ALLOWED_PROMPTS.add(f"{prefix}{lang}_speaker_{n}")
76
+
77
+
78
+ logger = logging.getLogger(__name__)
79
+
80
+
81
+ CUR_PATH = os.path.dirname(os.path.abspath(__file__))
82
+
83
+
84
+ default_cache_dir = os.path.join(os.path.expanduser("~"), ".cache")
85
+ CACHE_DIR = os.path.join(os.getenv("XDG_CACHE_HOME", default_cache_dir), "suno", "bark_v0")
86
+
87
+
88
+ def _cast_bool_env_var(s):
89
+ return s.lower() in ('true', '1', 't')
90
+
91
+
92
+ USE_SMALL_MODELS = _cast_bool_env_var(os.environ.get("SUNO_USE_SMALL_MODELS", "False"))
93
+ GLOBAL_ENABLE_MPS = _cast_bool_env_var(os.environ.get("SUNO_ENABLE_MPS", "False"))
94
+ OFFLOAD_CPU = _cast_bool_env_var(os.environ.get("SUNO_OFFLOAD_CPU", "False"))
95
+
96
+
97
+ REMOTE_MODEL_PATHS = {
98
+ "text_small": {
99
+ "repo_id": "suno/bark",
100
+ "file_name": "text.pt",
101
+ },
102
+ "coarse_small": {
103
+ "repo_id": "suno/bark",
104
+ "file_name": "coarse.pt",
105
+ },
106
+ "fine_small": {
107
+ "repo_id": "suno/bark",
108
+ "file_name": "fine.pt",
109
+ },
110
+ "text": {
111
+ "repo_id": "suno/bark",
112
+ "file_name": "text_2.pt",
113
+ },
114
+ "coarse": {
115
+ "repo_id": "suno/bark",
116
+ "file_name": "coarse_2.pt",
117
+ },
118
+ "fine": {
119
+ "repo_id": "suno/bark",
120
+ "file_name": "fine_2.pt",
121
+ },
122
+ }
123
+
124
+
125
+ if not hasattr(torch.nn.functional, 'scaled_dot_product_attention') and torch.cuda.is_available():
126
+ logger.warning(
127
+ "torch version does not support flash attention. You will get faster" +
128
+ " inference speed by upgrade torch to newest nightly version."
129
+ )
130
+
131
+
132
+ def _grab_best_device(use_gpu=True):
133
+ if torch.cuda.device_count() > 0 and use_gpu:
134
+ device = "cuda"
135
+ elif torch.backends.mps.is_available() and use_gpu and GLOBAL_ENABLE_MPS:
136
+ device = "mps"
137
+ else:
138
+ device = "cpu"
139
+ return device
140
+
141
+
142
+ def _get_ckpt_path(model_type, use_small=False):
143
+ key = model_type
144
+ if use_small or USE_SMALL_MODELS:
145
+ key += "_small"
146
+ return os.path.join(CACHE_DIR, REMOTE_MODEL_PATHS[key]["file_name"])
147
+
148
+
149
+ def _download(from_hf_path, file_name):
150
+ os.makedirs(CACHE_DIR, exist_ok=True)
151
+ hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=CACHE_DIR)
152
+
153
+
154
+ class InferenceContext:
155
+ def __init__(self, benchmark=False):
156
+ # we can't expect inputs to be the same length, so disable benchmarking by default
157
+ self._chosen_cudnn_benchmark = benchmark
158
+ self._cudnn_benchmark = None
159
+
160
+ def __enter__(self):
161
+ self._cudnn_benchmark = torch.backends.cudnn.benchmark
162
+ torch.backends.cudnn.benchmark = self._chosen_cudnn_benchmark
163
+
164
+ def __exit__(self, exc_type, exc_value, exc_traceback):
165
+ torch.backends.cudnn.benchmark = self._cudnn_benchmark
166
+
167
+
168
+ if torch.cuda.is_available():
169
+ torch.backends.cuda.matmul.allow_tf32 = True
170
+ torch.backends.cudnn.allow_tf32 = True
171
+
172
+
173
+ @contextlib.contextmanager
174
+ def _inference_mode():
175
+ with InferenceContext(), torch.inference_mode(), torch.no_grad(), autocast():
176
+ yield
177
+
178
+
179
+ def _clear_cuda_cache():
180
+ if torch.cuda.is_available():
181
+ torch.cuda.empty_cache()
182
+ torch.cuda.synchronize()
183
+
184
+
185
+ def clean_models(model_key=None):
186
+ global models
187
+ model_keys = [model_key] if model_key is not None else list(models.keys())
188
+ for k in model_keys:
189
+ if k in models:
190
+ del models[k]
191
+ _clear_cuda_cache()
192
+ gc.collect()
193
+
194
+
195
+ def _load_model(ckpt_path, device, use_small=False, model_type="text"):
196
+ if model_type == "text":
197
+ ConfigClass = GPTConfig
198
+ ModelClass = GPT
199
+ elif model_type == "coarse":
200
+ ConfigClass = GPTConfig
201
+ ModelClass = GPT
202
+ elif model_type == "fine":
203
+ ConfigClass = FineGPTConfig
204
+ ModelClass = FineGPT
205
+ else:
206
+ raise NotImplementedError()
207
+ model_key = f"{model_type}_small" if use_small or USE_SMALL_MODELS else model_type
208
+ model_info = REMOTE_MODEL_PATHS[model_key]
209
+ if not os.path.exists(ckpt_path):
210
+ logger.info(f"{model_type} model not found, downloading into `{CACHE_DIR}`.")
211
+ _download(model_info["repo_id"], model_info["file_name"])
212
+ checkpoint = torch.load(ckpt_path, map_location=device)
213
+ # this is a hack
214
+ model_args = checkpoint["model_args"]
215
+ if "input_vocab_size" not in model_args:
216
+ model_args["input_vocab_size"] = model_args["vocab_size"]
217
+ model_args["output_vocab_size"] = model_args["vocab_size"]
218
+ del model_args["vocab_size"]
219
+ gptconf = ConfigClass(**checkpoint["model_args"])
220
+ model = ModelClass(gptconf)
221
+ state_dict = checkpoint["model"]
222
+ # fixup checkpoint
223
+ unwanted_prefix = "_orig_mod."
224
+ for k, v in list(state_dict.items()):
225
+ if k.startswith(unwanted_prefix):
226
+ state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)
227
+ extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())
228
+ extra_keys = set([k for k in extra_keys if not k.endswith(".attn.bias")])
229
+ missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
230
+ missing_keys = set([k for k in missing_keys if not k.endswith(".attn.bias")])
231
+ if len(extra_keys) != 0:
232
+ raise ValueError(f"extra keys found: {extra_keys}")
233
+ if len(missing_keys) != 0:
234
+ raise ValueError(f"missing keys: {missing_keys}")
235
+ model.load_state_dict(state_dict, strict=False)
236
+ n_params = model.get_num_params()
237
+ val_loss = checkpoint["best_val_loss"].item()
238
+ logger.info(f"model loaded: {round(n_params/1e6,1)}M params, {round(val_loss,3)} loss")
239
+ model.eval()
240
+ model.to(device)
241
+ del checkpoint, state_dict
242
+ _clear_cuda_cache()
243
+ if model_type == "text":
244
+ tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
245
+ return {
246
+ "model": model,
247
+ "tokenizer": tokenizer,
248
+ }
249
+ return model
250
+
251
+
252
+ def _load_codec_model(device):
253
+ model = EncodecModel.encodec_model_24khz()
254
+ model.set_target_bandwidth(6.0)
255
+ model.eval()
256
+ model.to(device)
257
+ _clear_cuda_cache()
258
+ return model
259
+
260
+
261
+ def load_model(use_gpu=True, use_small=False, force_reload=False, model_type="text"):
262
+ _load_model_f = funcy.partial(_load_model, model_type=model_type, use_small=use_small)
263
+ if model_type not in ("text", "coarse", "fine"):
264
+ raise NotImplementedError()
265
+ global models
266
+ global models_devices
267
+ device = _grab_best_device(use_gpu=use_gpu)
268
+ model_key = f"{model_type}"
269
+ if OFFLOAD_CPU:
270
+ models_devices[model_key] = device
271
+ device = "cpu"
272
+ if model_key not in models or force_reload:
273
+ ckpt_path = _get_ckpt_path(model_type, use_small=use_small)
274
+ clean_models(model_key=model_key)
275
+ model = _load_model_f(ckpt_path, device)
276
+ models[model_key] = model
277
+ if model_type == "text":
278
+ models[model_key]["model"].to(device)
279
+ else:
280
+ models[model_key].to(device)
281
+ return models[model_key]
282
+
283
+
284
+ def load_codec_model(use_gpu=True, force_reload=False):
285
+ global models
286
+ global models_devices
287
+ device = _grab_best_device(use_gpu=use_gpu)
288
+ if device == "mps":
289
+ # encodec doesn't support mps
290
+ device = "cpu"
291
+ model_key = "codec"
292
+ if OFFLOAD_CPU:
293
+ models_devices[model_key] = device
294
+ device = "cpu"
295
+ if model_key not in models or force_reload:
296
+ clean_models(model_key=model_key)
297
+ model = _load_codec_model(device)
298
+ models[model_key] = model
299
+ models[model_key].to(device)
300
+ return models[model_key]
301
+
302
+
303
+ def preload_models(
304
+ text_use_gpu=True,
305
+ text_use_small=False,
306
+ coarse_use_gpu=True,
307
+ coarse_use_small=False,
308
+ fine_use_gpu=True,
309
+ fine_use_small=False,
310
+ codec_use_gpu=True,
311
+ force_reload=False,
312
+ ):
313
+ """Load all the necessary models for the pipeline."""
314
+ if _grab_best_device() == "cpu" and (
315
+ text_use_gpu or coarse_use_gpu or fine_use_gpu or codec_use_gpu
316
+ ):
317
+ logger.warning("No GPU being used. Careful, inference might be very slow!")
318
+ _ = load_model(
319
+ model_type="text", use_gpu=text_use_gpu, use_small=text_use_small, force_reload=force_reload
320
+ )
321
+ _ = load_model(
322
+ model_type="coarse",
323
+ use_gpu=coarse_use_gpu,
324
+ use_small=coarse_use_small,
325
+ force_reload=force_reload,
326
+ )
327
+ _ = load_model(
328
+ model_type="fine", use_gpu=fine_use_gpu, use_small=fine_use_small, force_reload=force_reload
329
+ )
330
+ _ = load_codec_model(use_gpu=codec_use_gpu, force_reload=force_reload)
331
+
332
+
333
+ ####
334
+ # Generation Functionality
335
+ ####
336
+
337
+
338
+ def _tokenize(tokenizer, text):
339
+ return tokenizer.encode(text, add_special_tokens=False)
340
+
341
+
342
+ def _detokenize(tokenizer, enc_text):
343
+ return tokenizer.decode(enc_text)
344
+
345
+
346
+ def _normalize_whitespace(text):
347
+ return re.sub(r"\s+", " ", text).strip()
348
+
349
+
350
+ TEXT_ENCODING_OFFSET = 10_048
351
+ SEMANTIC_PAD_TOKEN = 10_000
352
+ TEXT_PAD_TOKEN = 129_595
353
+ SEMANTIC_INFER_TOKEN = 129_599
354
+
355
+
356
+ def _load_history_prompt(history_prompt_input):
357
+ if isinstance(history_prompt_input, str) and history_prompt_input.endswith(".npz"):
358
+ history_prompt = np.load(history_prompt_input)
359
+ elif isinstance(history_prompt_input, str):
360
+ # make sure this works on non-ubuntu
361
+ history_prompt_input = os.path.join(*history_prompt_input.split("/"))
362
+ if history_prompt_input not in ALLOWED_PROMPTS:
363
+ raise ValueError("history prompt not found")
364
+ history_prompt = np.load(
365
+ os.path.join(CUR_PATH, "assets", "prompts", f"{history_prompt_input}.npz")
366
+ )
367
+ elif isinstance(history_prompt_input, dict):
368
+ assert("semantic_prompt" in history_prompt_input)
369
+ assert("coarse_prompt" in history_prompt_input)
370
+ assert("fine_prompt" in history_prompt_input)
371
+ history_prompt = history_prompt_input
372
+ else:
373
+ raise ValueError("history prompt format unrecognized")
374
+ return history_prompt
375
+
376
+
377
+ def generate_text_semantic(
378
+ text,
379
+ history_prompt=None,
380
+ temp=0.7,
381
+ top_k=None,
382
+ top_p=None,
383
+ silent=False,
384
+ min_eos_p=0.2,
385
+ max_gen_duration_s=None,
386
+ allow_early_stop=True,
387
+ use_kv_caching=False,
388
+ ):
389
+ """Generate semantic tokens from text."""
390
+ assert isinstance(text, str)
391
+ text = _normalize_whitespace(text)
392
+ assert len(text.strip()) > 0
393
+ if history_prompt is not None:
394
+ history_prompt = _load_history_prompt(history_prompt)
395
+ semantic_history = history_prompt["semantic_prompt"]
396
+ assert (
397
+ isinstance(semantic_history, np.ndarray)
398
+ and len(semantic_history.shape) == 1
399
+ and len(semantic_history) > 0
400
+ and semantic_history.min() >= 0
401
+ and semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
402
+ )
403
+ else:
404
+ semantic_history = None
405
+ # load models if not yet exist
406
+ global models
407
+ global models_devices
408
+ if "text" not in models:
409
+ preload_models()
410
+ model_container = models["text"]
411
+ model = model_container["model"]
412
+ tokenizer = model_container["tokenizer"]
413
+ encoded_text = np.array(_tokenize(tokenizer, text)) + TEXT_ENCODING_OFFSET
414
+ if OFFLOAD_CPU:
415
+ model.to(models_devices["text"])
416
+ device = next(model.parameters()).device
417
+ if len(encoded_text) > 256:
418
+ p = round((len(encoded_text) - 256) / len(encoded_text) * 100, 1)
419
+ logger.warning(f"warning, text too long, lopping of last {p}%")
420
+ encoded_text = encoded_text[:256]
421
+ encoded_text = np.pad(
422
+ encoded_text,
423
+ (0, 256 - len(encoded_text)),
424
+ constant_values=TEXT_PAD_TOKEN,
425
+ mode="constant",
426
+ )
427
+ if semantic_history is not None:
428
+ semantic_history = semantic_history.astype(np.int64)
429
+ # lop off if history is too long, pad if needed
430
+ semantic_history = semantic_history[-256:]
431
+ semantic_history = np.pad(
432
+ semantic_history,
433
+ (0, 256 - len(semantic_history)),
434
+ constant_values=SEMANTIC_PAD_TOKEN,
435
+ mode="constant",
436
+ )
437
+ else:
438
+ semantic_history = np.array([SEMANTIC_PAD_TOKEN] * 256)
439
+ x = torch.from_numpy(
440
+ np.hstack([
441
+ encoded_text, semantic_history, np.array([SEMANTIC_INFER_TOKEN])
442
+ ]).astype(np.int64)
443
+ )[None]
444
+ assert x.shape[1] == 256 + 256 + 1
445
+ with _inference_mode():
446
+ x = x.to(device)
447
+ n_tot_steps = 768
448
+ # custom tqdm updates since we don't know when eos will occur
449
+ pbar = tqdm.tqdm(disable=silent, total=n_tot_steps)
450
+ pbar_state = 0
451
+ tot_generated_duration_s = 0
452
+ kv_cache = None
453
+ for n in range(n_tot_steps):
454
+ if use_kv_caching and kv_cache is not None:
455
+ x_input = x[:, [-1]]
456
+ else:
457
+ x_input = x
458
+ logits, kv_cache = model(
459
+ x_input, merge_context=True, use_cache=use_kv_caching, past_kv=kv_cache
460
+ )
461
+ relevant_logits = logits[0, 0, :SEMANTIC_VOCAB_SIZE]
462
+ if allow_early_stop:
463
+ relevant_logits = torch.hstack(
464
+ (relevant_logits, logits[0, 0, [SEMANTIC_PAD_TOKEN]]) # eos
465
+ )
466
+ if top_p is not None:
467
+ # faster to convert to numpy
468
+ original_device = relevant_logits.device
469
+ relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
470
+ sorted_indices = np.argsort(relevant_logits)[::-1]
471
+ sorted_logits = relevant_logits[sorted_indices]
472
+ cumulative_probs = np.cumsum(softmax(sorted_logits))
473
+ sorted_indices_to_remove = cumulative_probs > top_p
474
+ sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
475
+ sorted_indices_to_remove[0] = False
476
+ relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
477
+ relevant_logits = torch.from_numpy(relevant_logits)
478
+ relevant_logits = relevant_logits.to(original_device)
479
+ if top_k is not None:
480
+ v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
481
+ relevant_logits[relevant_logits < v[-1]] = -float("Inf")
482
+ probs = F.softmax(relevant_logits / temp, dim=-1)
483
+ item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
484
+ if allow_early_stop and (
485
+ item_next == SEMANTIC_VOCAB_SIZE
486
+ or (min_eos_p is not None and probs[-1] >= min_eos_p)
487
+ ):
488
+ # eos found, so break
489
+ pbar.update(n - pbar_state)
490
+ break
491
+ x = torch.cat((x, item_next[None]), dim=1)
492
+ tot_generated_duration_s += 1 / SEMANTIC_RATE_HZ
493
+ if max_gen_duration_s is not None and tot_generated_duration_s > max_gen_duration_s:
494
+ pbar.update(n - pbar_state)
495
+ break
496
+ if n == n_tot_steps - 1:
497
+ pbar.update(n - pbar_state)
498
+ break
499
+ del logits, relevant_logits, probs, item_next
500
+
501
+ if n > pbar_state:
502
+ if n > pbar.total:
503
+ pbar.total = n
504
+ pbar.update(n - pbar_state)
505
+ pbar_state = n
506
+ pbar.total = n
507
+ pbar.refresh()
508
+ pbar.close()
509
+ out = x.detach().cpu().numpy().squeeze()[256 + 256 + 1 :]
510
+ if OFFLOAD_CPU:
511
+ model.to("cpu")
512
+ assert all(0 <= out) and all(out < SEMANTIC_VOCAB_SIZE)
513
+ _clear_cuda_cache()
514
+ return out
515
+
516
+
517
+ def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):
518
+ assert len(arr.shape) == 2
519
+ arr = arr.copy()
520
+ if offset_size is not None:
521
+ for n in range(1, arr.shape[0]):
522
+ arr[n, :] += offset_size * n
523
+ flat_arr = arr.ravel("F")
524
+ return flat_arr
525
+
526
+
527
+ COARSE_SEMANTIC_PAD_TOKEN = 12_048
528
+ COARSE_INFER_TOKEN = 12_050
529
+
530
+
531
+ def generate_coarse(
532
+ x_semantic,
533
+ history_prompt=None,
534
+ temp=0.7,
535
+ top_k=None,
536
+ top_p=None,
537
+ silent=False,
538
+ max_coarse_history=630, # min 60 (faster), max 630 (more context)
539
+ sliding_window_len=60,
540
+ use_kv_caching=False,
541
+ ):
542
+ """Generate coarse audio codes from semantic tokens."""
543
+ assert (
544
+ isinstance(x_semantic, np.ndarray)
545
+ and len(x_semantic.shape) == 1
546
+ and len(x_semantic) > 0
547
+ and x_semantic.min() >= 0
548
+ and x_semantic.max() <= SEMANTIC_VOCAB_SIZE - 1
549
+ )
550
+ assert 60 <= max_coarse_history <= 630
551
+ assert max_coarse_history + sliding_window_len <= 1024 - 256
552
+ semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS
553
+ max_semantic_history = int(np.floor(max_coarse_history / semantic_to_coarse_ratio))
554
+ if history_prompt is not None:
555
+ history_prompt = _load_history_prompt(history_prompt)
556
+ x_semantic_history = history_prompt["semantic_prompt"]
557
+ x_coarse_history = history_prompt["coarse_prompt"]
558
+ assert (
559
+ isinstance(x_semantic_history, np.ndarray)
560
+ and len(x_semantic_history.shape) == 1
561
+ and len(x_semantic_history) > 0
562
+ and x_semantic_history.min() >= 0
563
+ and x_semantic_history.max() <= SEMANTIC_VOCAB_SIZE - 1
564
+ and isinstance(x_coarse_history, np.ndarray)
565
+ and len(x_coarse_history.shape) == 2
566
+ and x_coarse_history.shape[0] == N_COARSE_CODEBOOKS
567
+ and x_coarse_history.shape[-1] >= 0
568
+ and x_coarse_history.min() >= 0
569
+ and x_coarse_history.max() <= CODEBOOK_SIZE - 1
570
+ and (
571
+ round(x_coarse_history.shape[-1] / len(x_semantic_history), 1)
572
+ == round(semantic_to_coarse_ratio / N_COARSE_CODEBOOKS, 1)
573
+ )
574
+ )
575
+ x_coarse_history = _flatten_codebooks(x_coarse_history) + SEMANTIC_VOCAB_SIZE
576
+ # trim histories correctly
577
+ n_semantic_hist_provided = np.min(
578
+ [
579
+ max_semantic_history,
580
+ len(x_semantic_history) - len(x_semantic_history) % 2,
581
+ int(np.floor(len(x_coarse_history) / semantic_to_coarse_ratio)),
582
+ ]
583
+ )
584
+ n_coarse_hist_provided = int(round(n_semantic_hist_provided * semantic_to_coarse_ratio))
585
+ x_semantic_history = x_semantic_history[-n_semantic_hist_provided:].astype(np.int32)
586
+ x_coarse_history = x_coarse_history[-n_coarse_hist_provided:].astype(np.int32)
587
+ # TODO: bit of a hack for time alignment (sounds better)
588
+ x_coarse_history = x_coarse_history[:-2]
589
+ else:
590
+ x_semantic_history = np.array([], dtype=np.int32)
591
+ x_coarse_history = np.array([], dtype=np.int32)
592
+ # load models if not yet exist
593
+ global models
594
+ global models_devices
595
+ if "coarse" not in models:
596
+ preload_models()
597
+ model = models["coarse"]
598
+ if OFFLOAD_CPU:
599
+ model.to(models_devices["coarse"])
600
+ device = next(model.parameters()).device
601
+ # start loop
602
+ n_steps = int(
603
+ round(
604
+ np.floor(len(x_semantic) * semantic_to_coarse_ratio / N_COARSE_CODEBOOKS)
605
+ * N_COARSE_CODEBOOKS
606
+ )
607
+ )
608
+ assert n_steps > 0 and n_steps % N_COARSE_CODEBOOKS == 0
609
+ x_semantic = np.hstack([x_semantic_history, x_semantic]).astype(np.int32)
610
+ x_coarse = x_coarse_history.astype(np.int32)
611
+ base_semantic_idx = len(x_semantic_history)
612
+ with _inference_mode():
613
+ x_semantic_in = torch.from_numpy(x_semantic)[None].to(device)
614
+ x_coarse_in = torch.from_numpy(x_coarse)[None].to(device)
615
+ n_window_steps = int(np.ceil(n_steps / sliding_window_len))
616
+ n_step = 0
617
+ for _ in tqdm.tqdm(range(n_window_steps), total=n_window_steps, disable=silent):
618
+ semantic_idx = base_semantic_idx + int(round(n_step / semantic_to_coarse_ratio))
619
+ # pad from right side
620
+ x_in = x_semantic_in[:, np.max([0, semantic_idx - max_semantic_history]) :]
621
+ x_in = x_in[:, :256]
622
+ x_in = F.pad(
623
+ x_in,
624
+ (0, 256 - x_in.shape[-1]),
625
+ "constant",
626
+ COARSE_SEMANTIC_PAD_TOKEN,
627
+ )
628
+ x_in = torch.hstack(
629
+ [
630
+ x_in,
631
+ torch.tensor([COARSE_INFER_TOKEN])[None].to(device),
632
+ x_coarse_in[:, -max_coarse_history:],
633
+ ]
634
+ )
635
+ kv_cache = None
636
+ for _ in range(sliding_window_len):
637
+ if n_step >= n_steps:
638
+ continue
639
+ is_major_step = n_step % N_COARSE_CODEBOOKS == 0
640
+
641
+ if use_kv_caching and kv_cache is not None:
642
+ x_input = x_in[:, [-1]]
643
+ else:
644
+ x_input = x_in
645
+
646
+ logits, kv_cache = model(x_input, use_cache=use_kv_caching, past_kv=kv_cache)
647
+ logit_start_idx = (
648
+ SEMANTIC_VOCAB_SIZE + (1 - int(is_major_step)) * CODEBOOK_SIZE
649
+ )
650
+ logit_end_idx = (
651
+ SEMANTIC_VOCAB_SIZE + (2 - int(is_major_step)) * CODEBOOK_SIZE
652
+ )
653
+ relevant_logits = logits[0, 0, logit_start_idx:logit_end_idx]
654
+ if top_p is not None:
655
+ # faster to convert to numpy
656
+ original_device = relevant_logits.device
657
+ relevant_logits = relevant_logits.detach().cpu().type(torch.float32).numpy()
658
+ sorted_indices = np.argsort(relevant_logits)[::-1]
659
+ sorted_logits = relevant_logits[sorted_indices]
660
+ cumulative_probs = np.cumsum(softmax(sorted_logits))
661
+ sorted_indices_to_remove = cumulative_probs > top_p
662
+ sorted_indices_to_remove[1:] = sorted_indices_to_remove[:-1].copy()
663
+ sorted_indices_to_remove[0] = False
664
+ relevant_logits[sorted_indices[sorted_indices_to_remove]] = -np.inf
665
+ relevant_logits = torch.from_numpy(relevant_logits)
666
+ relevant_logits = relevant_logits.to(original_device)
667
+ if top_k is not None:
668
+ v, _ = torch.topk(relevant_logits, min(top_k, relevant_logits.size(-1)))
669
+ relevant_logits[relevant_logits < v[-1]] = -float("Inf")
670
+ probs = F.softmax(relevant_logits / temp, dim=-1)
671
+ item_next = torch.multinomial(probs, num_samples=1).to(torch.int32)
672
+ item_next += logit_start_idx
673
+ x_coarse_in = torch.cat((x_coarse_in, item_next[None]), dim=1)
674
+ x_in = torch.cat((x_in, item_next[None]), dim=1)
675
+ del logits, relevant_logits, probs, item_next
676
+ n_step += 1
677
+ del x_in
678
+ del x_semantic_in
679
+ if OFFLOAD_CPU:
680
+ model.to("cpu")
681
+ gen_coarse_arr = x_coarse_in.detach().cpu().numpy().squeeze()[len(x_coarse_history) :]
682
+ del x_coarse_in
683
+ assert len(gen_coarse_arr) == n_steps
684
+ gen_coarse_audio_arr = gen_coarse_arr.reshape(-1, N_COARSE_CODEBOOKS).T - SEMANTIC_VOCAB_SIZE
685
+ for n in range(1, N_COARSE_CODEBOOKS):
686
+ gen_coarse_audio_arr[n, :] -= n * CODEBOOK_SIZE
687
+ _clear_cuda_cache()
688
+ return gen_coarse_audio_arr
689
+
690
+
691
+ def generate_fine(
692
+ x_coarse_gen,
693
+ history_prompt=None,
694
+ temp=0.5,
695
+ silent=True,
696
+ ):
697
+ """Generate full audio codes from coarse audio codes."""
698
+ assert (
699
+ isinstance(x_coarse_gen, np.ndarray)
700
+ and len(x_coarse_gen.shape) == 2
701
+ and 1 <= x_coarse_gen.shape[0] <= N_FINE_CODEBOOKS - 1
702
+ and x_coarse_gen.shape[1] > 0
703
+ and x_coarse_gen.min() >= 0
704
+ and x_coarse_gen.max() <= CODEBOOK_SIZE - 1
705
+ )
706
+ if history_prompt is not None:
707
+ history_prompt = _load_history_prompt(history_prompt)
708
+ x_fine_history = history_prompt["fine_prompt"]
709
+ assert (
710
+ isinstance(x_fine_history, np.ndarray)
711
+ and len(x_fine_history.shape) == 2
712
+ and x_fine_history.shape[0] == N_FINE_CODEBOOKS
713
+ and x_fine_history.shape[1] >= 0
714
+ and x_fine_history.min() >= 0
715
+ and x_fine_history.max() <= CODEBOOK_SIZE - 1
716
+ )
717
+ else:
718
+ x_fine_history = None
719
+ n_coarse = x_coarse_gen.shape[0]
720
+ # load models if not yet exist
721
+ global models
722
+ global models_devices
723
+ if "fine" not in models:
724
+ preload_models()
725
+ model = models["fine"]
726
+ if OFFLOAD_CPU:
727
+ model.to(models_devices["fine"])
728
+ device = next(model.parameters()).device
729
+ # make input arr
730
+ in_arr = np.vstack(
731
+ [
732
+ x_coarse_gen,
733
+ np.zeros((N_FINE_CODEBOOKS - n_coarse, x_coarse_gen.shape[1]))
734
+ + CODEBOOK_SIZE, # padding
735
+ ]
736
+ ).astype(np.int32)
737
+ # prepend history if available (max 512)
738
+ if x_fine_history is not None:
739
+ x_fine_history = x_fine_history.astype(np.int32)
740
+ in_arr = np.hstack(
741
+ [
742
+ x_fine_history[:, -512:].astype(np.int32),
743
+ in_arr,
744
+ ]
745
+ )
746
+ n_history = x_fine_history[:, -512:].shape[1]
747
+ else:
748
+ n_history = 0
749
+ n_remove_from_end = 0
750
+ # need to pad if too short (since non-causal model)
751
+ if in_arr.shape[1] < 1024:
752
+ n_remove_from_end = 1024 - in_arr.shape[1]
753
+ in_arr = np.hstack(
754
+ [
755
+ in_arr,
756
+ np.zeros((N_FINE_CODEBOOKS, n_remove_from_end), dtype=np.int32) + CODEBOOK_SIZE,
757
+ ]
758
+ )
759
+ # we can be lazy about fractional loop and just keep overwriting codebooks
760
+ n_loops = np.max([0, int(np.ceil((x_coarse_gen.shape[1] - (1024 - n_history)) / 512))]) + 1
761
+ with _inference_mode():
762
+ in_arr = torch.tensor(in_arr.T).to(device)
763
+ for n in tqdm.tqdm(range(n_loops), disable=silent):
764
+ start_idx = np.min([n * 512, in_arr.shape[0] - 1024])
765
+ start_fill_idx = np.min([n_history + n * 512, in_arr.shape[0] - 512])
766
+ rel_start_fill_idx = start_fill_idx - start_idx
767
+ in_buffer = in_arr[start_idx : start_idx + 1024, :][None]
768
+ for nn in range(n_coarse, N_FINE_CODEBOOKS):
769
+ logits = model(nn, in_buffer)
770
+ if temp is None:
771
+ relevant_logits = logits[0, rel_start_fill_idx:, :CODEBOOK_SIZE]
772
+ codebook_preds = torch.argmax(relevant_logits, -1)
773
+ else:
774
+ relevant_logits = logits[0, :, :CODEBOOK_SIZE] / temp
775
+ probs = F.softmax(relevant_logits, dim=-1)
776
+ codebook_preds = torch.multinomial(
777
+ probs[rel_start_fill_idx:1024], num_samples=1
778
+ ).reshape(-1)
779
+ codebook_preds = codebook_preds.to(torch.int32)
780
+ in_buffer[0, rel_start_fill_idx:, nn] = codebook_preds
781
+ del logits, codebook_preds
782
+ # transfer over info into model_in and convert to numpy
783
+ for nn in range(n_coarse, N_FINE_CODEBOOKS):
784
+ in_arr[
785
+ start_fill_idx : start_fill_idx + (1024 - rel_start_fill_idx), nn
786
+ ] = in_buffer[0, rel_start_fill_idx:, nn]
787
+ del in_buffer
788
+ gen_fine_arr = in_arr.detach().cpu().numpy().squeeze().T
789
+ del in_arr
790
+ if OFFLOAD_CPU:
791
+ model.to("cpu")
792
+ gen_fine_arr = gen_fine_arr[:, n_history:]
793
+ if n_remove_from_end > 0:
794
+ gen_fine_arr = gen_fine_arr[:, :-n_remove_from_end]
795
+ assert gen_fine_arr.shape[-1] == x_coarse_gen.shape[-1]
796
+ _clear_cuda_cache()
797
+ return gen_fine_arr
798
+
799
+
800
+ def codec_decode(fine_tokens):
801
+ """Turn quantized audio codes into audio array using encodec."""
802
+ # load models if not yet exist
803
+ global models
804
+ global models_devices
805
+ if "codec" not in models:
806
+ preload_models()
807
+ model = models["codec"]
808
+ if OFFLOAD_CPU:
809
+ model.to(models_devices["codec"])
810
+ device = next(model.parameters()).device
811
+ arr = torch.from_numpy(fine_tokens)[None]
812
+ arr = arr.to(device)
813
+ arr = arr.transpose(0, 1)
814
+ emb = model.quantizer.decode(arr)
815
+ out = model.decoder(emb)
816
+ audio_arr = out.detach().cpu().numpy().squeeze()
817
+ del arr, emb, out
818
+ if OFFLOAD_CPU:
819
+ model.to("cpu")
820
+ return audio_arr
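
The four public generation functions above chain into the same pipeline that api.py wraps. A minimal sketch, again assuming the package imports as `bark`:

    from bark.generation import (
        codec_decode,
        generate_coarse,
        generate_fine,
        generate_text_semantic,
        preload_models,
    )

    preload_models()  # honours SUNO_USE_SMALL_MODELS / SUNO_OFFLOAD_CPU / SUNO_ENABLE_MPS

    semantic = generate_text_semantic("Hello world.", temp=0.7, use_kv_caching=True)
    coarse = generate_coarse(semantic, temp=0.7, use_kv_caching=True)  # 2 coarse codebooks
    fine = generate_fine(coarse, temp=0.5)                             # fills out all 8 codebooks
    audio = codec_decode(fine)                                         # EnCodec decode, 24 kHz float audio
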
model.py ADDED
@@ -0,0 +1,218 @@
1
+ """
2
+ Much of this code is adapted from Andrej Karpathy's NanoGPT
3
+ (https://github.com/karpathy/nanoGPT)
4
+ """
5
+ import math
6
+ from dataclasses import dataclass
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+
12
+ class LayerNorm(nn.Module):
13
+ """ LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """
14
+
15
+ def __init__(self, ndim, bias):
16
+ super().__init__()
17
+ self.weight = nn.Parameter(torch.ones(ndim))
18
+ self.bias = nn.Parameter(torch.zeros(ndim)) if bias else None
19
+
20
+ def forward(self, input):
21
+ return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)
22
+
23
+ class CausalSelfAttention(nn.Module):
24
+
25
+ def __init__(self, config):
26
+ super().__init__()
27
+ assert config.n_embd % config.n_head == 0
28
+ # key, query, value projections for all heads, but in a batch
29
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
30
+ # output projection
31
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
32
+ # regularization
33
+ self.attn_dropout = nn.Dropout(config.dropout)
34
+ self.resid_dropout = nn.Dropout(config.dropout)
35
+ self.n_head = config.n_head
36
+ self.n_embd = config.n_embd
37
+ self.dropout = config.dropout
38
+ # flash attention makes the GPU go brrrrr, but support is only in PyTorch nightly and still a bit scary
39
+ self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
40
+ if not self.flash:
41
+ # print("WARNING: using slow attention. Flash Attention atm needs PyTorch nightly and dropout=0.0")
42
+ # causal mask to ensure that attention is only applied to the left in the input sequence
43
+ self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
44
+ .view(1, 1, config.block_size, config.block_size))
45
+
46
+ def forward(self, x, past_kv=None, use_cache=False):
47
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
48
+
49
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
50
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
51
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
52
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
53
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
54
+
55
+ if past_kv is not None:
56
+ past_key = past_kv[0]
57
+ past_value = past_kv[1]
58
+ k = torch.cat((past_key, k), dim=-2)
59
+ v = torch.cat((past_value, v), dim=-2)
60
+
61
+ FULL_T = k.shape[-2]
62
+
63
+ if use_cache is True:
64
+ present = (k, v)
65
+ else:
66
+ present = None
67
+
68
+ # causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
69
+ if self.flash:
70
+ # efficient attention using Flash Attention CUDA kernels
71
+ if past_kv is not None:
72
+ # When `past_kv` is provided, we're doing incremental decoding and `q.shape[2] == 1`: q only contains
73
+ # the query for the last token. scaled_dot_product_attention interprets this as the first token in the
74
+ # sequence, so if is_causal=True it will mask out all attention from it. This is not what we want, so
75
+ # to work around this we set is_causal=False.
76
+ is_causal = False
77
+ else:
78
+ is_causal = True
79
+
80
+ y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=self.dropout, is_causal=is_causal)
81
+ else:
82
+ # manual implementation of attention
83
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
84
+ att = att.masked_fill(self.bias[:,:,FULL_T-T:FULL_T,:FULL_T] == 0, float('-inf'))
85
+ att = F.softmax(att, dim=-1)
86
+ att = self.attn_dropout(att)
87
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
88
+ y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
89
+
90
+ # output projection
91
+ y = self.resid_dropout(self.c_proj(y))
92
+ return (y, present)
93
+
94
+ class MLP(nn.Module):
95
+
96
+ def __init__(self, config):
97
+ super().__init__()
98
+ self.c_fc = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)
99
+ self.c_proj = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)
100
+ self.dropout = nn.Dropout(config.dropout)
101
+ self.gelu = nn.GELU()
102
+
103
+ def forward(self, x):
104
+ x = self.c_fc(x)
105
+ x = self.gelu(x)
106
+ x = self.c_proj(x)
107
+ x = self.dropout(x)
108
+ return x
109
+
110
+ class Block(nn.Module):
111
+
112
+ def __init__(self, config, layer_idx):
113
+ super().__init__()
114
+ self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)
115
+ self.attn = CausalSelfAttention(config)
116
+ self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)
117
+ self.mlp = MLP(config)
118
+ self.layer_idx = layer_idx
119
+
120
+ def forward(self, x, past_kv=None, use_cache=False):
121
+ attn_output, prev_kvs = self.attn(self.ln_1(x), past_kv=past_kv, use_cache=use_cache)
122
+ x = x + attn_output
123
+ x = x + self.mlp(self.ln_2(x))
124
+ return (x, prev_kvs)
125
+
126
+ @dataclass
127
+ class GPTConfig:
128
+ block_size: int = 1024
129
+ input_vocab_size: int = 10_048
130
+ output_vocab_size: int = 10_048
131
+ n_layer: int = 12
132
+ n_head: int = 12
133
+ n_embd: int = 768
134
+ dropout: float = 0.0
135
+ bias: bool = True # True: bias in Linears and LayerNorms, like GPT-2. False: a bit better and faster
136
+
137
+ class GPT(nn.Module):
138
+
139
+ def __init__(self, config):
140
+ super().__init__()
141
+ assert config.input_vocab_size is not None
142
+ assert config.output_vocab_size is not None
143
+ assert config.block_size is not None
144
+ self.config = config
145
+
146
+ self.transformer = nn.ModuleDict(dict(
147
+ wte = nn.Embedding(config.input_vocab_size, config.n_embd),
148
+ wpe = nn.Embedding(config.block_size, config.n_embd),
149
+ drop = nn.Dropout(config.dropout),
150
+ h = nn.ModuleList([Block(config, idx) for idx in range(config.n_layer)]),
151
+ ln_f = LayerNorm(config.n_embd, bias=config.bias),
152
+ ))
153
+ self.lm_head = nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
154
+
155
+ def get_num_params(self, non_embedding=True):
156
+ """
157
+ Return the number of parameters in the model.
158
+ For non-embedding count (default), the position embeddings get subtracted.
159
+ The token embeddings would too, except due to the parameter sharing these
160
+ params are actually used as weights in the final layer, so we include them.
161
+ """
162
+ n_params = sum(p.numel() for p in self.parameters())
163
+ if non_embedding:
164
+ n_params -= self.transformer.wte.weight.numel()
165
+ n_params -= self.transformer.wpe.weight.numel()
166
+ return n_params
167
+
168
+ def forward(self, idx, merge_context=False, past_kv=None, position_ids=None, use_cache=False):
169
+ device = idx.device
170
+ b, t = idx.size()
171
+ if past_kv is not None:
172
+ assert t == 1
173
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
174
+ else:
175
+ if merge_context:
176
+ assert(idx.shape[1] >= 256+256+1)
177
+ t = idx.shape[1] - 256
178
+ else:
179
+ assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
180
+
181
+ # forward the GPT model itself
182
+ if merge_context:
183
+ tok_emb = torch.cat([
184
+ self.transformer.wte(idx[:,:256]) + self.transformer.wte(idx[:,256:256+256]),
185
+ self.transformer.wte(idx[:,256+256:])
186
+ ], dim=1)
187
+ else:
188
+ tok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)
189
+
190
+ if past_kv is None:
191
+ past_length = 0
192
+ past_kv = tuple([None] * len(self.transformer.h))
193
+ else:
194
+ past_length = past_kv[0][0].size(-2)
195
+
196
+ if position_ids is None:
197
+ position_ids = torch.arange(past_length, t + past_length, dtype=torch.long, device=device)
198
+ position_ids = position_ids.unsqueeze(0) # shape (1, t)
199
+ assert position_ids.shape == (1, t)
200
+
201
+ pos_emb = self.transformer.wpe(position_ids) # position embeddings of shape (1, t, n_embd)
202
+
203
+ x = self.transformer.drop(tok_emb + pos_emb)
204
+
205
+ new_kv = () if use_cache else None
206
+
207
+ for i, (block, past_layer_kv) in enumerate(zip(self.transformer.h, past_kv)):
208
+ x, kv = block(x, past_kv=past_layer_kv, use_cache=use_cache)
209
+
210
+ if use_cache:
211
+ new_kv = new_kv + (kv,)
212
+
213
+ x = self.transformer.ln_f(x)
214
+
215
+ # inference-time mini-optimization: only forward the lm_head on the very last position
216
+ logits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dim
217
+
218
+ return (logits, new_kv)
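
A small shape-check sketch for the causal model above, using toy config values rather than the real Bark checkpoint hyperparameters (import path assumed to be `bark.model`):

    import torch
    from bark.model import GPT, GPTConfig

    config = GPTConfig(block_size=1024, input_vocab_size=128, output_vocab_size=128,
                       n_layer=2, n_head=2, n_embd=64, dropout=0.0, bias=False)
    model = GPT(config).eval()

    idx = torch.randint(0, 128, (1, 16))  # (batch, time)
    with torch.no_grad():
        # only the last position's logits are returned: shape (1, 1, 128)
        logits, kv = model(idx, use_cache=True)
        # incremental decoding: feed one new token plus the cached keys/values
        next_tok = torch.randint(0, 128, (1, 1))
        logits, kv = model(next_tok, past_kv=kv, use_cache=True)
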
model_fine.py ADDED
@@ -0,0 +1,149 @@
1
+ """
2
+ Much of this code is adapted from Andrej Karpathy's NanoGPT
3
+ (https://github.com/karpathy/nanoGPT)
4
+ """
5
+ from dataclasses import dataclass
6
+ import math
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+ from torch.nn import functional as F
11
+
12
+ from .model import GPT, GPTConfig, MLP
13
+
14
+
15
+ class NonCausalSelfAttention(nn.Module):
16
+ def __init__(self, config):
17
+ super().__init__()
18
+ assert config.n_embd % config.n_head == 0
19
+ # key, query, value projections for all heads, but in a batch
20
+ self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
21
+ # output projection
22
+ self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
23
+ # regularization
24
+ self.attn_dropout = nn.Dropout(config.dropout)
25
+ self.resid_dropout = nn.Dropout(config.dropout)
26
+ self.n_head = config.n_head
27
+ self.n_embd = config.n_embd
28
+ self.dropout = config.dropout
29
+ # flash attention makes the GPU go brrrrr, but support is only in PyTorch >= 2.0
30
+ self.flash = (
31
+ hasattr(torch.nn.functional, "scaled_dot_product_attention")
32
+ )
33
+
34
+ def forward(self, x):
35
+ B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
36
+
37
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
38
+ q, k, v = self.c_attn(x).split(self.n_embd, dim=2)
39
+ k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
40
+ q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
41
+ v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)
42
+
43
+ # non-causal self-attention: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)
44
+ if self.flash:
45
+ # efficient attention using Flash Attention CUDA kernels
46
+ y = torch.nn.functional.scaled_dot_product_attention(
47
+ q, k, v, attn_mask=None, dropout_p=self.dropout, is_causal=False
48
+ )
49
+ else:
50
+ # manual implementation of attention
51
+ att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
52
+ att = F.softmax(att, dim=-1)
53
+ att = self.attn_dropout(att)
54
+ y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
55
+ y = (
56
+ y.transpose(1, 2).contiguous().view(B, T, C)
57
+ ) # re-assemble all head outputs side by side
58
+
59
+ # output projection
60
+ y = self.resid_dropout(self.c_proj(y))
61
+ return y
62
+
63
+
64
+ class FineBlock(nn.Module):
65
+ def __init__(self, config):
66
+ super().__init__()
67
+ self.ln_1 = nn.LayerNorm(config.n_embd)
68
+ self.attn = NonCausalSelfAttention(config)
69
+ self.ln_2 = nn.LayerNorm(config.n_embd)
70
+ self.mlp = MLP(config)
71
+
72
+ def forward(self, x):
73
+ x = x + self.attn(self.ln_1(x))
74
+ x = x + self.mlp(self.ln_2(x))
75
+ return x
76
+
77
+
78
+ class FineGPT(GPT):
79
+ def __init__(self, config):
80
+ super().__init__(config)
81
+ del self.lm_head
82
+ self.config = config
83
+ self.n_codes_total = config.n_codes_total
84
+ self.transformer = nn.ModuleDict(
85
+ dict(
86
+ wtes=nn.ModuleList(
87
+ [
88
+ nn.Embedding(config.input_vocab_size, config.n_embd)
89
+ for _ in range(config.n_codes_total)
90
+ ]
91
+ ),
92
+ wpe=nn.Embedding(config.block_size, config.n_embd),
93
+ drop=nn.Dropout(config.dropout),
94
+ h=nn.ModuleList([FineBlock(config) for _ in range(config.n_layer)]),
95
+ ln_f=nn.LayerNorm(config.n_embd),
96
+ )
97
+ )
98
+ self.lm_heads = nn.ModuleList(
99
+ [
100
+ nn.Linear(config.n_embd, config.output_vocab_size, bias=False)
101
+ for _ in range(config.n_codes_given, self.n_codes_total)
102
+ ]
103
+ )
104
+ for i in range(self.n_codes_total - config.n_codes_given):
105
+ self.transformer.wtes[i + 1].weight = self.lm_heads[i].weight
106
+
107
+ def forward(self, pred_idx, idx):
108
+ device = idx.device
109
+ b, t, codes = idx.size()
110
+ assert (
111
+ t <= self.config.block_size
112
+ ), f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"
113
+ assert pred_idx > 0, "cannot predict 0th codebook"
114
+ assert codes == self.n_codes_total, (b, t, codes)
115
+ pos = torch.arange(0, t, dtype=torch.long, device=device).unsqueeze(0) # shape (1, t)
116
+
117
+ # forward the GPT model itself
118
+ tok_embs = [
119
+ wte(idx[:, :, i]).unsqueeze(-1) for i, wte in enumerate(self.transformer.wtes)
120
+ ] # token embeddings of shape (b, t, n_embd)
121
+ tok_emb = torch.cat(tok_embs, dim=-1)
122
+ pos_emb = self.transformer.wpe(pos) # position embeddings of shape (1, t, n_embd)
123
+ x = tok_emb[:, :, :, : pred_idx + 1].sum(dim=-1)
124
+ x = self.transformer.drop(x + pos_emb)
125
+ for block in self.transformer.h:
126
+ x = block(x)
127
+ x = self.transformer.ln_f(x)
128
+ logits = self.lm_heads[pred_idx - self.config.n_codes_given](x)
129
+ return logits
130
+
131
+ def get_num_params(self, non_embedding=True):
132
+ """
133
+ Return the number of parameters in the model.
134
+ For non-embedding count (default), the position embeddings get subtracted.
135
+ The token embeddings would too, except due to the parameter sharing these
136
+ params are actually used as weights in the final layer, so we include them.
137
+ """
138
+ n_params = sum(p.numel() for p in self.parameters())
139
+ if non_embedding:
140
+ for wte in self.transformer.wtes:
141
+ n_params -= wte.weight.numel()
142
+ n_params -= self.transformer.wpe.weight.numel()
143
+ return n_params
144
+
145
+
146
+ @dataclass
147
+ class FineGPTConfig(GPTConfig):
148
+ n_codes_total: int = 8
149
+ n_codes_given: int = 1
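
A matching sketch for the non-causal fine model, which predicts one codebook at a time given the codebooks below it (toy config values again; import path assumed to be `bark.model_fine`):

    import torch
    from bark.model_fine import FineGPT, FineGPTConfig

    config = FineGPTConfig(block_size=1024, input_vocab_size=1056, output_vocab_size=1056,
                           n_layer=2, n_head=2, n_embd=64, dropout=0.0, bias=False,
                           n_codes_total=8, n_codes_given=1)
    model = FineGPT(config).eval()

    # idx carries all 8 codebooks per timestep: (batch, time, n_codes_total)
    idx = torch.randint(0, 1056, (1, 1024, 8))
    with torch.no_grad():
        logits = model(pred_idx=2, idx=idx)  # predict codebook 2 from codebooks 0..2
        assert logits.shape == (1, 1024, 1056)
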