mrfakename committed on
Commit 775c5c6 · 1 Parent(s): fc3fbd4
.gitattributes DELETED
@@ -1,37 +0,0 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
36
- prompts/countdown_mono.wav filter=lfs diff=lfs merge=lfs -text
37
- prompts/toaskanymore.wav filter=lfs diff=lfs merge=lfs -text
 
LICENSE DELETED
@@ -1,201 +0,0 @@
1
- Apache License
2
- Version 2.0, January 2004
3
- http://www.apache.org/licenses/
4
-
5
- TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
-
7
- 1. Definitions.
8
-
9
- "License" shall mean the terms and conditions for use, reproduction,
10
- and distribution as defined by Sections 1 through 9 of this document.
11
-
12
- "Licensor" shall mean the copyright owner or entity authorized by
13
- the copyright owner that is granting the License.
14
-
15
- "Legal Entity" shall mean the union of the acting entity and all
16
- other entities that control, are controlled by, or are under common
17
- control with that entity. For the purposes of this definition,
18
- "control" means (i) the power, direct or indirect, to cause the
19
- direction or management of such entity, whether by contract or
20
- otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
- outstanding shares, or (iii) beneficial ownership of such entity.
22
-
23
- "You" (or "Your") shall mean an individual or Legal Entity
24
- exercising permissions granted by this License.
25
-
26
- "Source" form shall mean the preferred form for making modifications,
27
- including but not limited to software source code, documentation
28
- source, and configuration files.
29
-
30
- "Object" form shall mean any form resulting from mechanical
31
- transformation or translation of a Source form, including but
32
- not limited to compiled object code, generated documentation,
33
- and conversions to other media types.
34
-
35
- "Work" shall mean the work of authorship, whether in Source or
36
- Object form, made available under the License, as indicated by a
37
- copyright notice that is included in or attached to the work
38
- (an example is provided in the Appendix below).
39
-
40
- "Derivative Works" shall mean any work, whether in Source or Object
41
- form, that is based on (or derived from) the Work and for which the
42
- editorial revisions, annotations, elaborations, or other modifications
43
- represent, as a whole, an original work of authorship. For the purposes
44
- of this License, Derivative Works shall not include works that remain
45
- separable from, or merely link (or bind by name) to the interfaces of,
46
- the Work and Derivative Works thereof.
47
-
48
- "Contribution" shall mean any work of authorship, including
49
- the original version of the Work and any modifications or additions
50
- to that Work or Derivative Works thereof, that is intentionally
51
- submitted to Licensor for inclusion in the Work by the copyright owner
52
- or by an individual or Legal Entity authorized to submit on behalf of
53
- the copyright owner. For the purposes of this definition, "submitted"
54
- means any form of electronic, verbal, or written communication sent
55
- to the Licensor or its representatives, including but not limited to
56
- communication on electronic mailing lists, source code control systems,
57
- and issue tracking systems that are managed by, or on behalf of, the
58
- Licensor for the purpose of discussing and improving the Work, but
59
- excluding communication that is conspicuously marked or otherwise
60
- designated in writing by the copyright owner as "Not a Contribution."
61
-
62
- "Contributor" shall mean Licensor and any individual or Legal Entity
63
- on behalf of whom a Contribution has been received by Licensor and
64
- subsequently incorporated within the Work.
65
-
66
- 2. Grant of Copyright License. Subject to the terms and conditions of
67
- this License, each Contributor hereby grants to You a perpetual,
68
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
- copyright license to reproduce, prepare Derivative Works of,
70
- publicly display, publicly perform, sublicense, and distribute the
71
- Work and such Derivative Works in Source or Object form.
72
-
73
- 3. Grant of Patent License. Subject to the terms and conditions of
74
- this License, each Contributor hereby grants to You a perpetual,
75
- worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
- (except as stated in this section) patent license to make, have made,
77
- use, offer to sell, sell, import, and otherwise transfer the Work,
78
- where such license applies only to those patent claims licensable
79
- by such Contributor that are necessarily infringed by their
80
- Contribution(s) alone or by combination of their Contribution(s)
81
- with the Work to which such Contribution(s) was submitted. If You
82
- institute patent litigation against any entity (including a
83
- cross-claim or counterclaim in a lawsuit) alleging that the Work
84
- or a Contribution incorporated within the Work constitutes direct
85
- or contributory patent infringement, then any patent licenses
86
- granted to You under this License for that Work shall terminate
87
- as of the date such litigation is filed.
88
-
89
- 4. Redistribution. You may reproduce and distribute copies of the
90
- Work or Derivative Works thereof in any medium, with or without
91
- modifications, and in Source or Object form, provided that You
92
- meet the following conditions:
93
-
94
- (a) You must give any other recipients of the Work or
95
- Derivative Works a copy of this License; and
96
-
97
- (b) You must cause any modified files to carry prominent notices
98
- stating that You changed the files; and
99
-
100
- (c) You must retain, in the Source form of any Derivative Works
101
- that You distribute, all copyright, patent, trademark, and
102
- attribution notices from the Source form of the Work,
103
- excluding those notices that do not pertain to any part of
104
- the Derivative Works; and
105
-
106
- (d) If the Work includes a "NOTICE" text file as part of its
107
- distribution, then any Derivative Works that You distribute must
108
- include a readable copy of the attribution notices contained
109
- within such NOTICE file, excluding those notices that do not
110
- pertain to any part of the Derivative Works, in at least one
111
- of the following places: within a NOTICE text file distributed
112
- as part of the Derivative Works; within the Source form or
113
- documentation, if provided along with the Derivative Works; or,
114
- within a display generated by the Derivative Works, if and
115
- wherever such third-party notices normally appear. The contents
116
- of the NOTICE file are for informational purposes only and
117
- do not modify the License. You may add Your own attribution
118
- notices within Derivative Works that You distribute, alongside
119
- or as an addendum to the NOTICE text from the Work, provided
120
- that such additional attribution notices cannot be construed
121
- as modifying the License.
122
-
123
- You may add Your own copyright statement to Your modifications and
124
- may provide additional or different license terms and conditions
125
- for use, reproduction, or distribution of Your modifications, or
126
- for any such Derivative Works as a whole, provided Your use,
127
- reproduction, and distribution of the Work otherwise complies with
128
- the conditions stated in this License.
129
-
130
- 5. Submission of Contributions. Unless You explicitly state otherwise,
131
- any Contribution intentionally submitted for inclusion in the Work
132
- by You to the Licensor shall be under the terms and conditions of
133
- this License, without any additional terms or conditions.
134
- Notwithstanding the above, nothing herein shall supersede or modify
135
- the terms of any separate license agreement you may have executed
136
- with Licensor regarding such Contributions.
137
-
138
- 6. Trademarks. This License does not grant permission to use the trade
139
- names, trademarks, service marks, or product names of the Licensor,
140
- except as required for reasonable and customary use in describing the
141
- origin of the Work and reproducing the content of the NOTICE file.
142
-
143
- 7. Disclaimer of Warranty. Unless required by applicable law or
144
- agreed to in writing, Licensor provides the Work (and each
145
- Contributor provides its Contributions) on an "AS IS" BASIS,
146
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
- implied, including, without limitation, any warranties or conditions
148
- of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
- PARTICULAR PURPOSE. You are solely responsible for determining the
150
- appropriateness of using or redistributing the Work and assume any
151
- risks associated with Your exercise of permissions under this License.
152
-
153
- 8. Limitation of Liability. In no event and under no legal theory,
154
- whether in tort (including negligence), contract, or otherwise,
155
- unless required by applicable law (such as deliberate and grossly
156
- negligent acts) or agreed to in writing, shall any Contributor be
157
- liable to You for damages, including any direct, indirect, special,
158
- incidental, or consequential damages of any character arising as a
159
- result of this License or out of the use or inability to use the
160
- Work (including but not limited to damages for loss of goodwill,
161
- work stoppage, computer failure or malfunction, or any and all
162
- other commercial damages or losses), even if such Contributor
163
- has been advised of the possibility of such damages.
164
-
165
- 9. Accepting Warranty or Additional Liability. While redistributing
166
- the Work or Derivative Works thereof, You may choose to offer,
167
- and charge a fee for, acceptance of support, warranty, indemnity,
168
- or other liability obligations and/or rights consistent with this
169
- License. However, in accepting such obligations, You may act only
170
- on Your own behalf and on Your sole responsibility, not on behalf
171
- of any other Contributor, and only if You agree to indemnify,
172
- defend, and hold each Contributor harmless for any liability
173
- incurred by, or claims asserted against, such Contributor by reason
174
- of your accepting any such warranty or additional liability.
175
-
176
- END OF TERMS AND CONDITIONS
177
-
178
- APPENDIX: How to apply the Apache License to your work.
179
-
180
- To apply the Apache License to your work, attach the following
181
- boilerplate notice, with the fields enclosed by brackets "[]"
182
- replaced with your own identifying information. (Don't include
183
- the brackets!) The text should be enclosed in the appropriate
184
- comment syntax for the file format. We also recommend that a
185
- file or class name and description of purpose be included on the
186
- same "printed page" as the copyright notice for easier
187
- identification within third-party archives.
188
-
189
- Copyright 2024 Standard Intelligence PBC
190
-
191
- Licensed under the Apache License, Version 2.0 (the "License");
192
- you may not use this file except in compliance with the License.
193
- You may obtain a copy of the License at
194
-
195
- http://www.apache.org/licenses/LICENSE-2.0
196
-
197
- Unless required by applicable law or agreed to in writing, software
198
- distributed under the License is distributed on an "AS IS" BASIS,
199
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
- See the License for the specific language governing permissions and
201
- limitations under the License.
 
README.md DELETED
@@ -1,13 +0,0 @@
1
- ---
2
- title: hertz-dev
3
- emoji: ⚡
4
- colorFrom: gray
5
- colorTo: gray
6
- sdk: gradio
7
- sdk_version: 5.4.0
8
- app_file: app.py
9
- pinned: false
10
- short_description: Unofficial demo for hertz-dev
11
- ---
12
-
13
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
app.py DELETED
@@ -1,142 +0,0 @@
1
- import torch as T
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- import torchaudio
5
- from utils import load_ckpt, print_colored
6
- from tokenizer import make_tokenizer
7
- from model import get_hertz_dev_config
8
- import matplotlib.pyplot as plt
9
- import spaces
10
- import gradio as gr
11
-
12
- device = 'cuda' if T.cuda.is_available() else 'cpu'
13
- #T.cuda.set_device(0)
14
- print(f"Using device: {device}")
15
-
16
- audio_tokenizer = make_tokenizer(device)
17
-
18
- TWO_SPEAKER = False
19
-
20
- model_config = get_hertz_dev_config(is_split=TWO_SPEAKER)
21
-
22
- generator = model_config()
23
- generator = generator.eval().to(T.bfloat16).to(device)
24
-
25
-
26
-
27
- ##############
28
- # Load audio
29
-
30
- def load_and_preprocess_audio(audio_path):
31
- gr.Info("Loading and preprocessing audio...")
32
- # Load audio file
33
- audio_tensor, sr = torchaudio.load(audio_path)
34
- gr.Info(f"Loaded audio shape: {audio_tensor.shape}")
35
-
36
- if TWO_SPEAKER:
37
- if audio_tensor.shape[0] == 1:
38
- gr.Info("Converting mono to stereo...")
39
- audio_tensor = audio_tensor.repeat(2, 1)
40
- gr.Info(f"Stereo audio shape: {audio_tensor.shape}")
41
- else:
42
- if audio_tensor.shape[0] == 2:
43
- gr.Info("Converting stereo to mono...")
44
- audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)
45
- gr.Info(f"Mono audio shape: {audio_tensor.shape}")
46
-
47
- # Resample to 16kHz if needed
48
- if sr != 16000:
49
- gr.Info(f"Resampling from {sr}Hz to 16000Hz...")
50
- resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
51
- audio_tensor = resampler(audio_tensor)
52
-
53
- # Clip to 5 minutes if needed
54
- max_samples = 16000 * 60 * 5
55
- if audio_tensor.shape[1] > max_samples:
56
- # gr.Info("Clipping audio to 5 minutes...")
57
- raise gr.Error("Maximum prompt is 5 minutes")
58
- # audio_tensor = audio_tensor[:, :max_samples]
59
-
60
- duration_seconds = audio_tensor.shape[1] / 16000  # audio has already been resampled to 16 kHz above
61
-
62
- gr.Info("Audio preprocessing complete!")
63
- return audio_tensor.unsqueeze(0), duration_seconds
64
-
65
- ##############
66
- # Return audio to gradio
67
-
68
- def display_audio(audio_tensor):
69
- audio_tensor = audio_tensor.cpu().squeeze()
70
- if audio_tensor.ndim == 1:
71
- audio_tensor = audio_tensor.unsqueeze(0)
72
- audio_tensor = audio_tensor.float()
73
-
74
- # Make a waveform plot
75
- # plt.figure(figsize=(4, 1))
76
- # plt.plot(audio_tensor.numpy()[0], linewidth=0.5)
77
- # plt.axis('off')
78
- # plt.show()
79
-
80
- # Make an audio player
81
- return (16000, audio_tensor.numpy())
82
-
83
- def get_completion(encoded_prompt_audio, prompt_len):
84
- prompt_len_seconds = prompt_len / 8
85
- gr.Info(f"Prompt length: {prompt_len_seconds:.2f}s")
86
- with T.autocast(device_type='cuda', dtype=T.bfloat16):
87
- completed_audio_batch = generator.completion(
88
- encoded_prompt_audio,
89
- temps=(.8, (0.5, 0.1)), # (token_temp, (categorical_temp, gaussian_temp))
90
- use_cache=True)
91
-
92
- completed_audio = completed_audio_batch
93
- print_colored(f"Decoding completion...", "blue")
94
- if TWO_SPEAKER:
95
- decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())
96
- decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())
97
- decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)
98
- else:
99
- decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())
100
- gr.Info(f"Decoded completion shape: {decoded_completion.shape}")
101
-
102
- gr.Info("Preparing audio for playback...")
103
-
104
- audio_tensor = decoded_completion.cpu().squeeze()
105
- if audio_tensor.ndim == 1:
106
- audio_tensor = audio_tensor.unsqueeze(0)
107
- audio_tensor = audio_tensor.float()
108
-
109
- if audio_tensor.abs().max() > 1:
110
- audio_tensor = audio_tensor / audio_tensor.abs().max()
111
-
112
- # return audio_tensor[:, max(prompt_len*2000 - 16000, 0):]
113
- return audio_tensor
114
-
115
- @spaces.GPU
116
- def run(audio_path):
117
- prompt_audio, prompt_len_seconds = load_and_preprocess_audio(audio_path)
118
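- # the audio tokenizer works at 8 latent frames per second, so seconds are converted to frames here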
- prompt_len = prompt_len_seconds * 8
119
- gr.Info("Encoding prompt...")
120
- with T.autocast(device_type='cuda', dtype=T.bfloat16):
121
- if TWO_SPEAKER:
122
- encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))
123
- encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))
124
- encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)
125
- else:
126
- encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))
127
- gr.Info(f"Encoded prompt shape: {encoded_prompt_audio.shape}")
128
- gr.Info("Prompt encoded successfully!")
129
- # num_completions = 10
130
- completion = get_completion(encoded_prompt_audio, prompt_len)
131
- return display_audio(completion)
132
-
133
-
134
-
135
- with gr.Blocks() as demo:
136
- gr.Markdown("# hertz-dev")
137
- inp = gr.Audio(label="Input Audio", type="filepath", interactive=True)
138
- btn = gr.Button("Continue", variant="primary")
139
- out = gr.Audio(label="Output", interactive=False)
140
- btn.click(run, inputs=inp, outputs=out)
141
-
142
- demo.queue().launch()
 
inference.ipynb DELETED
@@ -1,236 +0,0 @@
1
- {
2
- "cells": [
3
- {
4
- "cell_type": "code",
5
- "execution_count": null,
6
- "metadata": {},
7
- "outputs": [],
8
- "source": [
9
- "%load_ext autoreload\n",
10
- "%autoreload 2"
11
- ]
12
- },
13
- {
14
- "cell_type": "code",
15
- "execution_count": null,
16
- "metadata": {},
17
- "outputs": [],
18
- "source": [
19
- "import torch as T\n",
20
- "import torch.nn as nn\n",
21
- "import torch.nn.functional as F\n",
22
- "import torchaudio\n",
23
- "from utils import load_ckpt, print_colored\n",
24
- "from tokenizer import make_tokenizer\n",
25
- "from model import get_hertz_dev_config\n",
26
- "import matplotlib.pyplot as plt\n",
27
- "from IPython.display import Audio, display\n",
28
- "\n",
29
- "\n",
30
- "# If you get an error like \"undefined symbol: __nvJitLinkComplete_12_4, version libnvJitLink.so.12\",\n",
31
- "# you need to install PyTorch with the correct CUDA version. Run:\n",
32
- "# `pip3 uninstall torch torchaudio && pip3 install torch torchaudio --index-url https://download.pytorch.org/whl/cu121`\n",
33
- "\n",
34
- "device = 'cuda' if T.cuda.is_available() else 'cpu'\n",
35
- "T.cuda.set_device(0)\n",
36
- "print_colored(f\"Using device: {device}\", \"grey\")"
37
- ]
38
- },
39
- {
40
- "cell_type": "code",
41
- "execution_count": null,
42
- "metadata": {},
43
- "outputs": [],
44
- "source": [
45
- "# If you've already downloaded the model checkpoints, save them in ckpt/.\n",
46
- "# This code will automatically download them if it can't find them.\n",
47
- "audio_tokenizer = make_tokenizer(device)"
48
- ]
49
- },
50
- {
51
- "cell_type": "code",
52
- "execution_count": 7,
53
- "metadata": {},
54
- "outputs": [],
55
- "source": [
56
- "# We have different checkpoints for the single-speaker and two-speaker models\n",
57
- "# Set to True to load and run inference with the two-speaker model\n",
58
- "TWO_SPEAKER = False"
59
- ]
60
- },
61
- {
62
- "cell_type": "code",
63
- "execution_count": null,
64
- "metadata": {},
65
- "outputs": [],
66
- "source": [
67
- "model_config = get_hertz_dev_config(is_split=TWO_SPEAKER)\n",
68
- "\n",
69
- "generator = model_config()\n",
70
- "generator = generator.eval().to(T.bfloat16).to(device)"
71
- ]
72
- },
73
- {
74
- "cell_type": "code",
75
- "execution_count": null,
76
- "metadata": {},
77
- "outputs": [],
78
- "source": [
79
- "def load_and_preprocess_audio(audio_path):\n",
80
- " print_colored(\"Loading and preprocessing audio...\", \"blue\", bold=True)\n",
81
- " # Load audio file\n",
82
- " audio_tensor, sr = torchaudio.load(audio_path)\n",
83
- " print_colored(f\"Loaded audio shape: {audio_tensor.shape}\", \"grey\")\n",
84
- " \n",
85
- " if TWO_SPEAKER:\n",
86
- " if audio_tensor.shape[0] == 1:\n",
87
- " print_colored(\"Converting mono to stereo...\", \"grey\")\n",
88
- " audio_tensor = audio_tensor.repeat(2, 1)\n",
89
- " print_colored(f\"Stereo audio shape: {audio_tensor.shape}\", \"grey\")\n",
90
- " else:\n",
91
- " if audio_tensor.shape[0] == 2:\n",
92
- " print_colored(\"Converting stereo to mono...\", \"grey\")\n",
93
- " audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)\n",
94
- " print_colored(f\"Mono audio shape: {audio_tensor.shape}\", \"grey\")\n",
95
- " \n",
96
- " # Resample to 16kHz if needed\n",
97
- " if sr != 16000:\n",
98
- " print_colored(f\"Resampling from {sr}Hz to 16000Hz...\", \"grey\")\n",
99
- " resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)\n",
100
- " audio_tensor = resampler(audio_tensor)\n",
101
- " \n",
102
- " # Clip to 5 minutes if needed\n",
103
- " max_samples = 16000 * 60 * 5\n",
104
- " if audio_tensor.shape[1] > max_samples:\n",
105
- " print_colored(\"Clipping audio to 5 minutes...\", \"grey\")\n",
106
- " audio_tensor = audio_tensor[:, :max_samples]\n",
107
- "\n",
108
- " \n",
109
- " print_colored(\"Audio preprocessing complete!\", \"green\")\n",
110
- " return audio_tensor.unsqueeze(0)\n",
111
- "\n",
112
- "def display_audio(audio_tensor):\n",
113
- " audio_tensor = audio_tensor.cpu().squeeze()\n",
114
- " if audio_tensor.ndim == 1:\n",
115
- " audio_tensor = audio_tensor.unsqueeze(0)\n",
116
- " audio_tensor = audio_tensor.float()\n",
117
- "\n",
118
- " # Make a waveform plot\n",
119
- " plt.figure(figsize=(4, 1))\n",
120
- " plt.plot(audio_tensor.numpy()[0], linewidth=0.5)\n",
121
- " plt.axis('off')\n",
122
- " plt.show()\n",
123
- "\n",
124
- " # Make an audio player\n",
125
- " display(Audio(audio_tensor.numpy(), rate=16000))\n",
126
- " print_colored(f\"Audio ready for playback ↑\", \"green\", bold=True)\n",
127
- " \n",
128
- " \n",
129
- "\n",
130
- "# Our model is very prompt-sensitive, so we recommend experimenting with a diverse set of prompts.\n",
131
- "prompt_audio = load_and_preprocess_audio('./prompts/toaskanymore.wav')\n",
132
- "display_audio(prompt_audio)\n",
133
- "prompt_len_seconds = 3\n",
134
- "prompt_len = prompt_len_seconds * 8"
135
- ]
136
- },
137
- {
138
- "cell_type": "code",
139
- "execution_count": null,
140
- "metadata": {},
141
- "outputs": [],
142
- "source": [
143
- "print_colored(\"Encoding prompt...\", \"blue\")\n",
144
- "with T.autocast(device_type='cuda', dtype=T.bfloat16):\n",
145
- " if TWO_SPEAKER:\n",
146
- " encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))\n",
147
- " encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))\n",
148
- " encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)\n",
149
- " else:\n",
150
- " encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))\n",
151
- "print_colored(f\"Encoded prompt shape: {encoded_prompt_audio.shape}\", \"grey\")\n",
152
- "print_colored(\"Prompt encoded successfully!\", \"green\")"
153
- ]
154
- },
155
- {
156
- "cell_type": "code",
157
- "execution_count": null,
158
- "metadata": {},
159
- "outputs": [],
160
- "source": [
161
- "def get_completion(encoded_prompt_audio, prompt_len):\n",
162
- " prompt_len_seconds = prompt_len / 8\n",
163
- " print_colored(f\"Prompt length: {prompt_len_seconds:.2f}s\", \"grey\")\n",
164
- " print_colored(\"Completing audio...\", \"blue\")\n",
165
- " with T.autocast(device_type='cuda', dtype=T.bfloat16):\n",
166
- " completed_audio_batch = generator.completion(\n",
167
- " encoded_prompt_audio, \n",
168
- " temps=(.8, (0.5, 0.1)), # (token_temp, (categorical_temp, gaussian_temp))\n",
169
- " use_cache=True)\n",
170
- "\n",
171
- " completed_audio = completed_audio_batch\n",
172
- " print_colored(f\"Decoding completion...\", \"blue\")\n",
173
- " if TWO_SPEAKER:\n",
174
- " decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())\n",
175
- " decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())\n",
176
- " decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)\n",
177
- " else:\n",
178
- " decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())\n",
179
- " print_colored(f\"Decoded completion shape: {decoded_completion.shape}\", \"grey\")\n",
180
- "\n",
181
- " print_colored(\"Preparing audio for playback...\", \"blue\")\n",
182
- "\n",
183
- " audio_tensor = decoded_completion.cpu().squeeze()\n",
184
- " if audio_tensor.ndim == 1:\n",
185
- " audio_tensor = audio_tensor.unsqueeze(0)\n",
186
- " audio_tensor = audio_tensor.float()\n",
187
- "\n",
188
- " if audio_tensor.abs().max() > 1:\n",
189
- " audio_tensor = audio_tensor / audio_tensor.abs().max()\n",
190
- "\n",
191
- " return audio_tensor[:, max(prompt_len*2000 - 16000, 0):]\n",
192
- "\n",
193
- "num_completions = 10\n",
194
- "print_colored(f\"Generating {num_completions} completions...\", \"blue\")\n",
195
- "for _ in range(num_completions):\n",
196
- " completion = get_completion(encoded_prompt_audio, prompt_len)\n",
197
- " display_audio(completion)"
198
- ]
199
- },
200
- {
201
- "cell_type": "code",
202
- "execution_count": null,
203
- "metadata": {},
204
- "outputs": [],
205
- "source": []
206
- },
207
- {
208
- "cell_type": "code",
209
- "execution_count": null,
210
- "metadata": {},
211
- "outputs": [],
212
- "source": []
213
- }
214
- ],
215
- "metadata": {
216
- "kernelspec": {
217
- "display_name": ".venv",
218
- "language": "python",
219
- "name": "python3"
220
- },
221
- "language_info": {
222
- "codemirror_mode": {
223
- "name": "ipython",
224
- "version": 3
225
- },
226
- "file_extension": ".py",
227
- "mimetype": "text/x-python",
228
- "name": "python",
229
- "nbconvert_exporter": "python",
230
- "pygments_lexer": "ipython3",
231
- "version": "3.10.12"
232
- }
233
- },
234
- "nbformat": 4,
235
- "nbformat_minor": 2
236
- }
 
inference_client.py DELETED
@@ -1,161 +0,0 @@
1
- # server.py remains the same as before
2
-
3
- # Updated client.py
4
- import asyncio
5
- import websockets
6
- import sounddevice as sd
7
- import numpy as np
8
- import base64
9
- import queue
10
- import argparse
11
- import requests
12
- import time
13
-
14
- class AudioClient:
15
- def __init__(self, server_url="ws://localhost:8000", token_temp=None, categorical_temp=None, gaussian_temp=None):
16
- # Convert ws:// to http:// for the base URL
17
- self.base_url = server_url.replace("ws://", "http://")
18
- self.server_url = f"{server_url}/audio"
19
-
20
- # Set temperatures if provided
21
- if any(t is not None for t in [token_temp, categorical_temp, gaussian_temp]):
22
- self.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp)
23
-
24
- # Initialize queues
25
- self.audio_queue = queue.Queue()
26
- self.output_queue = queue.Queue()
27
-
28
- def set_temperature_and_echo(self, token_temp=None, categorical_temp=None, gaussian_temp=None, echo_testing = False):
29
- """Send temperature settings to server"""
30
- params = {}
31
- if token_temp is not None:
32
- params['token_temp'] = token_temp
33
- if categorical_temp is not None:
34
- params['categorical_temp'] = categorical_temp
35
- if gaussian_temp is not None:
36
- params['gaussian_temp'] = gaussian_temp
37
-
38
- response = requests.post(f"{self.base_url}/set_temperature", params=params)
39
- print(response.json()['message'])
40
-
41
- def audio_callback(self, indata, frames, time, status):
42
- """This is called for each audio block"""
43
- if status:
44
- print(status)
45
- # if np.isclose(indata, 0).all():
46
- # raise Exception('Audio input is not working - received all zeros')
47
- # Convert float32 to int16 for efficient transmission
48
- indata_int16 = (indata.copy() * 32767).astype(np.int16)
49
- # indata_int16 = np.zeros_like(indata_int16)
50
- self.audio_queue.put(indata_int16)
51
-
52
- def output_stream_callback(self, outdata, frames, time, status):
53
- """Callback for output stream to get audio data"""
54
- if status:
55
- print(status)
56
-
57
- try:
58
- data = self.output_queue.get_nowait()
59
- data = data.astype(np.float32) / 32767.0
60
- if len(data) < len(outdata):
61
- outdata[:len(data)] = data
62
- outdata[len(data):] = 0
63
- else:
64
- outdata[:] = data[:len(outdata)]
65
- except queue.Empty:
66
- outdata.fill(0)
67
-
68
- async def process_audio(self):
69
- async with websockets.connect(self.server_url) as ws:
70
- while self.running:
71
- if not self.audio_queue.empty():
72
- # Get recorded audio
73
- audio_data = self.audio_queue.get()
74
- print(f'Data from microphone:{audio_data.shape, audio_data.dtype, audio_data.min(), audio_data.max()}')
75
-
76
- # Convert to base64
77
- audio_b64 = base64.b64encode(audio_data.tobytes()).decode('utf-8')
78
-
79
- # Send to server
80
- time_sent = time.time()
81
- await ws.send(f"data:audio/raw;base64,{audio_b64}")
82
-
83
- # Receive processed audio
84
- response = await ws.recv()
85
- response = response.split(",")[1]
86
- time_received = time.time()
87
- print(f"Data sent: {audio_b64[:10]}. Data received: {response[:10]}. Received in {(time_received - time_sent) * 1000:.2f} ms")
88
- processed_audio = np.frombuffer(
89
- base64.b64decode(response),
90
- dtype=np.int16
91
- ).reshape(-1, CHANNELS)
92
- print(f'Data from model:{processed_audio.shape, processed_audio.dtype, processed_audio.min(), processed_audio.max()}')
93
-
94
- self.output_queue.put(processed_audio)
95
-
96
- def start(self):
97
- self.running = True
98
- # Print audio device information
99
- devices = sd.query_devices()
100
- default_input = sd.query_devices(kind='input')
101
- default_output = sd.query_devices(kind='output')
102
-
103
- print("\nAudio Device Configuration:")
104
- print("-" * 50)
105
- print(f"Default Input Device:\n{default_input}\n")
106
- print(f"Default Output Device:\n{default_output}\n")
107
- print("\nAll Available Devices:")
108
- print("-" * 50)
109
- for i, device in enumerate(devices):
110
- print(f"Device {i}:")
111
- print(f"Name: {device['name']}")
112
- print(f"Channels (in/out): {device['max_input_channels']}/{device['max_output_channels']}")
113
- print(f"Sample Rates: {device['default_samplerate']}")
114
- print()
115
- input_device = input("Enter the index of the input device or press enter for default: ")
116
- output_device = input("Enter the index of the output device or press enter for default: ")
117
- if input_device == "":
118
- input_device = default_input['index']
119
- if output_device == "":
120
- output_device = default_output['index']
121
- with sd.InputStream(callback=self.audio_callback,
122
- channels=CHANNELS,
123
- samplerate=SAMPLE_RATE,
124
- device=int(input_device),
125
- blocksize=2000), \
126
- sd.OutputStream(callback=self.output_stream_callback,
127
- channels=CHANNELS,
128
- samplerate=SAMPLE_RATE,
129
- blocksize=2000,
130
- device=int(output_device)):
131
-
132
- asyncio.run(self.process_audio())
133
-
134
- def stop(self):
135
- self.running = False
136
-
137
- if __name__ == "__main__":
138
- parser = argparse.ArgumentParser(description='Audio Client with Temperature Control')
139
- parser.add_argument('--token_temp', '-t1', type=float, help='Token (LM) temperature parameter')
140
- parser.add_argument('--categorical_temp', '-t2', type=float, help='Categorical (VAE) temperature parameter')
141
- parser.add_argument('--gaussian_temp', '-t3', type=float, help='Gaussian (VAE) temperature parameter')
142
- parser.add_argument('--server', '-s', default="ws://localhost:8000",
143
- help='Server URL (default: ws://localhost:8000)')
144
-
145
- args = parser.parse_args()
146
-
147
- # Audio settings
148
- SAMPLE_RATE = 16000
149
- CHANNELS = 1
150
-
151
- client = AudioClient(
152
- server_url=args.server,
153
- token_temp=args.token_temp,
154
- categorical_temp=args.categorical_temp,
155
- gaussian_temp=args.gaussian_temp
156
- )
157
-
158
- try:
159
- client.start()
160
- except KeyboardInterrupt:
161
- client.stop()
 
inference_server.py DELETED
@@ -1,170 +0,0 @@
1
- import time
2
- import numpy as np
3
- from fastapi import FastAPI, WebSocket
4
- from fastapi.middleware.cors import CORSMiddleware
5
- import base64
6
- import uvicorn
7
- import traceback
8
- import numpy as np
9
- import argparse
10
-
11
- import torch as T
12
- import torch.nn.functional as F
13
- import torchaudio
14
-
15
- import os
16
- from typing import Optional
17
-
18
- from utils import print_colored
19
- from model import get_hertz_dev_config
20
-
21
-
22
- parser = argparse.ArgumentParser()
23
-
24
- parser.add_argument('--prompt_path', type=str, default='./prompts/bob_mono.wav', help="""
25
- We highly recommend making your own prompt based on a conversation between you and another person.
26
- bob_mono.wav seems to work better for two-channel than bob_stereo.wav.
27
- """)
28
- args = parser.parse_args()
29
-
30
-
31
- device = 'cuda' if T.cuda.is_available() else 'cpu'
32
- print_colored(f"Using device: {device}", "grey")
33
-
34
- model_config = get_hertz_dev_config(is_split=True)
35
-
36
- model = model_config()
37
- model = model.eval().bfloat16().to(device)
38
-
39
- app = FastAPI()
40
-
41
- app.add_middleware(
42
- CORSMiddleware,
43
- allow_origins=["*"],
44
- allow_credentials=True,
45
- allow_methods=["*"],
46
- allow_headers=["*"],
47
- )
48
-
49
-
50
- # Hyperparams or something.
51
- SAMPLE_RATE = 16000 # Don't change this
52
- TEMPS = (0.8, (0.4, 0.1)) # You can change this, but there's also an endpoint for it.
53
-
54
- class AudioProcessor:
55
- def __init__(self, model, prompt_path):
56
- self.model = model
57
- self.prompt_path = prompt_path
58
- self.initialize_state(prompt_path)
59
-
60
- def initialize_state(self, prompt_path):
61
- loaded_audio, sr = torchaudio.load(prompt_path)
62
-
63
- if sr != SAMPLE_RATE:
64
- resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
65
- loaded_audio = resampler(loaded_audio)
66
-
67
- if loaded_audio.shape[0] == 1:
68
- loaded_audio = loaded_audio.repeat(2, 1)
69
-
70
- audio_length = loaded_audio.shape[-1]
71
- num_chunks = audio_length // 2000
72
- loaded_audio = loaded_audio[..., :num_chunks * 2000]
73
-
74
- self.loaded_audio = loaded_audio.to(device)
75
-
76
- with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode():
77
- self.model.init_cache(bsize=1, device=device, dtype=T.bfloat16, length=1024)
78
- self.next_model_audio = self.model.next_audio_from_audio(self.loaded_audio.unsqueeze(0), temps=TEMPS)
79
- self.prompt_buffer = None
80
- self.prompt_position = 0
81
- self.chunks_until_live = 24
82
- self.initialize_prompt_buffer()
83
- print_colored("AudioProcessor state initialized", "green")
84
-
85
- def initialize_prompt_buffer(self):
86
- self.recorded_audio = self.loaded_audio
87
- prompt_audio = self.loaded_audio.reshape(1, 2, -1)
88
- prompt_audio = prompt_audio[:, :, -48000:].cpu().numpy()
89
- prompt_audio_mono = prompt_audio.mean(axis=1)
90
- self.prompt_buffer = np.array_split(prompt_audio_mono[0], 24)
91
- print_colored(f"Initialized prompt buffer with {len(self.prompt_buffer)} chunks", "grey")
92
-
93
- async def process_audio(self, audio_data):
94
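- # Warm-up phase: replay the last ~3 s of the prompt (24 chunks of 2000 samples at 16 kHz) before switching to live generation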
- if self.chunks_until_live > 0:
95
- print_colored(f"Serving from prompt buffer, {self.chunks_until_live} chunks left", "grey")
96
- chunk = self.prompt_buffer[24 - self.chunks_until_live]
97
- self.chunks_until_live -= 1
98
-
99
- if self.chunks_until_live == 0:
100
- print_colored("Switching to live processing mode", "green")
101
-
102
- time.sleep(0.05)
103
- return chunk
104
-
105
- audio_tensor = T.from_numpy(audio_data).to(device)
106
- audio_tensor = audio_tensor.reshape(1, 1, -1)
107
- audio_tensor = T.cat([audio_tensor, self.next_model_audio], dim=1)
108
-
109
- with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode():
110
- curr_model_audio = self.model.next_audio_from_audio(
111
- audio_tensor,
112
- temps=TEMPS
113
- )
114
- print(f"Recorded audio shape {self.recorded_audio.shape}, audio tensor shape {audio_tensor.shape}")
115
- self.recorded_audio = T.cat([self.recorded_audio.cpu(), audio_tensor.squeeze(0).cpu()], dim=-1)
116
-
117
- self.next_model_audio = curr_model_audio
118
-
119
- return curr_model_audio.float().cpu().numpy()
120
-
121
- def cleanup(self):
122
- print_colored("Cleaning up audio processor...", "blue")
123
- os.makedirs('audio_recordings', exist_ok=True)
124
- torchaudio.save(f'audio_recordings/{time.strftime("%d-%H-%M")}.wav', self.recorded_audio.cpu(), SAMPLE_RATE)
125
- self.model.deinit_cache()
126
- self.initialize_state(self.prompt_path)
127
- print_colored("Audio processor cleanup complete", "green")
128
-
129
- @app.post("/set_temperature")
130
- async def set_temperature(token_temp: Optional[float] = None, categorical_temp: Optional[float] = None, gaussian_temp: Optional[float] = None):
131
- try:
132
- global TEMPS
133
- TEMPS = (token_temp, (categorical_temp, gaussian_temp))
134
-
135
- print_colored(f"Temperature updated to: {TEMPS}", "green")
136
- return {"message": f"Temperature updated to: {TEMPS}", "status": "success"}
137
- except Exception as e:
138
- print_colored(f"Error setting temperature: {str(e)}", "red")
139
- return {"message": f"Error setting temperature: {str(e)}", "status": "error"}
140
-
141
- @app.websocket("/audio")
142
- async def websocket_endpoint(websocket: WebSocket):
143
- await websocket.accept()
144
- try:
145
- while True:
146
- data = await websocket.receive_text()
147
- audio_data = np.frombuffer(
148
- base64.b64decode(data.split(",")[1]),
149
- dtype=np.int16
150
- )
151
- audio_data = audio_data.astype(np.float32) / 32767.0
152
- processed_audio = await audio_processor.process_audio(audio_data)
153
- processed_audio = (processed_audio * 32767).astype(np.int16)
154
-
155
- processed_data = base64.b64encode(processed_audio.tobytes()).decode('utf-8')
156
- await websocket.send_text(f"data:audio/raw;base64,{processed_data}")
157
-
158
- except Exception as e:
159
- print_colored(f"WebSocket error: {e}", "red")
160
- print_colored(f"Full traceback:\n{traceback.format_exc()}", "red")
161
- finally:
162
- audio_processor.cleanup()
163
- await websocket.close()
164
-
165
-
166
- audio_processor = AudioProcessor(model=model, prompt_path=args.prompt_path)
167
-
168
- if __name__ == "__main__":
169
- uvicorn.run(app, host="0.0.0.0", port=8000)
170
- print("Server started")
 
ioblocks.py DELETED
@@ -1,333 +0,0 @@
1
- from __future__ import annotations
2
- from functools import partial
3
- from contextlib import nullcontext
4
- from typing import List, Tuple
5
- from math import ceil
6
-
7
- import torch as T
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- import torch.distributed as dist
11
- from torch import Tensor, int32
12
- from torch.amp import autocast
13
-
14
- from einops import rearrange, pack, unpack
15
-
16
-
17
- from utils import si_module, exists, default, maybe
18
-
19
-
20
- @si_module
21
- class GaussianMixtureIOLayer(nn.Module):
22
- class Config:
23
- latent_dim: int
24
- dim: int
25
- num_components: int
26
-
27
- def __init__(self, c: Config):
28
- super().__init__()
29
- self.latent_dim = c.latent_dim
30
- self.num_components = c.num_components
31
- self.input_projection = nn.Linear(c.latent_dim, c.dim)
32
-
33
- self.fc_loc = nn.Linear(c.dim, c.num_components * c.latent_dim)
34
- self.fc_scale = nn.Linear(c.dim, c.num_components * c.latent_dim)
35
- self.fc_weight = nn.Linear(c.dim, c.num_components)
36
-
37
- def _square_plus(self, x):
38
- return (x + T.sqrt(T.square(x) + 4)) / 2
39
-
40
- def input(self, sampled_latents: T.Tensor) -> T.Tensor:
41
- """Pre-sampled latents T.Tensor (B, L, Z) -> float tensor (B, L, D)"""
42
- hidden = self.input_projection(sampled_latents)
43
- return hidden
44
-
45
- def output(self, h: T.Tensor) -> Tuple[T.Tensor, T.Tensor, T.Tensor]:
46
- """float tensor (B, L, D) -> Tuple of locs, scales, and weights"""
47
- batch_size, seq_len, _ = h.shape
48
-
49
- locs = self.fc_loc(h).view(batch_size, seq_len, self.num_components, self.latent_dim)
50
- scales = T.clamp(self._square_plus(self.fc_scale(h)), min=1e-6).view(batch_size, seq_len, self.num_components, self.latent_dim)
51
- weights = self.fc_weight(h).view(batch_size, seq_len, self.num_components)
52
-
53
- return (locs, scales, weights)
54
-
55
- def loss(self, data, dataHat):
56
- locs, scales, weights = dataHat
57
- log_probs = -0.5 * T.sum(
58
- (data.unsqueeze(-2) - locs).pow(2) / scales.pow(2) +
59
- 2 * T.log(scales) +
60
- T.log(T.tensor(2 * T.pi)),
61
- dim=-1
62
- )
63
- log_weights = F.log_softmax(weights, dim=-1)
64
- return -T.logsumexp(log_weights + log_probs, dim=-1)
65
-
66
-
67
- def temp_sample(self, orig_pdist, temp):
68
- locs, scales, weights = orig_pdist
69
- if temp is None:
70
- component_samples = locs + scales * T.randn_like(scales)
71
- mixture_samples = F.gumbel_softmax(weights, hard=True)
72
- sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
73
- elif isinstance(temp, tuple):
74
- assert len(temp) == 2
75
- categorical_temp, gaussian_temp = temp
76
- component_samples = locs + scales * gaussian_temp * T.randn_like(scales)
77
- mixture_samples = F.gumbel_softmax(weights / categorical_temp, hard=True)
78
- sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
79
- else:
80
- component_samples = locs + scales * temp * T.randn_like(scales)
81
- mixture_samples = F.gumbel_softmax(weights / temp, hard=True)
82
- sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
83
- return sampled
84
-
85
-
86
- class GPTOutput(nn.Module):
87
- def __init__(self, dim, vocab_size):
88
- super().__init__()
89
- self.output = nn.Linear(dim, vocab_size, bias=False)
90
-
91
- def forward(self, x):
92
- return self.output(x)
93
-
94
-
95
- # helper functions
96
-
97
- def pack_one(t, pattern):
98
- return pack([t], pattern)
99
-
100
- def unpack_one(t, ps, pattern):
101
- return unpack(t, ps, pattern)[0]
102
-
103
- def first(l):
104
- return l[0]
105
-
106
- def round_up_multiple(num, mult):
107
- return ceil(num / mult) * mult
108
-
109
- def get_code_utilization(codes, codebook_size, get_global=False):
110
- if get_global and dist.is_initialized():
111
- world_size = dist.get_world_size()
112
- else:
113
- world_size = 1
114
-
115
- if world_size > 1:
116
- gathered_tokens = [T.zeros_like(codes) for _ in range(world_size)]
117
- dist.all_gather(gathered_tokens, codes)
118
- gathered_tokens = T.cat(gathered_tokens, dim=0)
119
- else:
120
- gathered_tokens = codes
121
- unique_tokens = len(T.unique(gathered_tokens))
122
- code_utilization = unique_tokens / min(gathered_tokens.numel(), codebook_size)
123
- return code_utilization
124
-
125
- # tensor helpers
126
-
127
- def round_ste(z: Tensor) -> Tensor:
128
- """Round with straight through gradients."""
129
- zhat = z.round()
130
- return z + (zhat - z).detach()
131
-
132
- # main class
133
- # lucidrains fsq
134
- @si_module
135
- class FSQ(nn.Module):
136
- @property
137
- def needs_float32_params(self):
138
- return True
139
-
140
- class Config:
141
- levels: List[int]
142
- dim: int | None = None
143
- num_codebooks: int = 1
144
- keep_num_codebooks_dim: bool | None = None
145
- scale: float | None = None
146
- allowed_dtypes: Tuple[str, ...] = ('float32', 'float64')
147
- channel_first: bool = False
148
- projection_has_bias: bool = True
149
- return_indices: bool = True
150
- force_quantization_f32: bool = True
151
- use_rms: bool = False
152
-
153
- def __init__(self, c: Config):
154
- super().__init__()
155
- _levels = T.tensor(c.levels, dtype=int32)
156
- self.register_buffer("_levels", _levels, persistent = False)
157
-
158
- _basis = T.cumprod(T.tensor([1] + c.levels[:-1]), dim=0, dtype=int32)
159
- self.register_buffer("_basis", _basis, persistent = False)
160
-
161
- self.scale = c.scale
162
-
163
- codebook_dim = len(c.levels)
164
- self.codebook_dim = codebook_dim
165
-
166
- effective_codebook_dim = codebook_dim * c.num_codebooks
167
- self.num_codebooks = c.num_codebooks
168
-
169
- self.allowed_dtypes = []
170
- for dtype_str in c.allowed_dtypes:
171
- if hasattr(T, dtype_str):
172
- self.allowed_dtypes.append(getattr(T, dtype_str))
173
- else:
174
- raise ValueError(f"Invalid dtype string: {dtype_str}")
175
-
176
- self.effective_codebook_dim = effective_codebook_dim
177
-
178
- keep_num_codebooks_dim = default(c.keep_num_codebooks_dim, c.num_codebooks > 1)
179
- assert not (c.num_codebooks > 1 and not keep_num_codebooks_dim)
180
- self.keep_num_codebooks_dim = keep_num_codebooks_dim
181
-
182
- self.dim = default(c.dim, len(_levels) * c.num_codebooks)
183
-
184
- self.channel_first = c.channel_first
185
-
186
- has_projections = self.dim != effective_codebook_dim
187
- self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
188
- self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
189
-
190
- self.has_projections = has_projections
191
-
192
- self.return_indices = c.return_indices
193
- if c.return_indices:
194
- self.codebook_size = self._levels.prod().item()
195
- implicit_codebook = self._indices_to_codes(T.arange(self.codebook_size))
196
- self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
197
-
198
- self.allowed_dtypes = c.allowed_dtypes
199
- self.force_quantization_f32 = c.force_quantization_f32
200
-
201
- self.latent_loss = None
202
-
203
- def latent_metric(self, codes, get_global=False):
204
- return {'code_util_estimate': get_code_utilization(codes, self.codebook_size, get_global)}
205
-
206
- def repr_from_latent(self, latent):
207
- return self.indices_to_codes(latent)
208
-
209
- def bound(self, z, eps: float = 1e-3):
210
- """ Bound `z`, an array of shape (..., d). """
211
- half_l = (self._levels - 1) * (1 + eps) / 2
212
- offset = T.where(self._levels % 2 == 0, 0.5, 0.0)
213
- shift = (offset / half_l).atanh()
214
- return (z + shift).tanh() * half_l - offset
215
-
216
- def quantize(self, z):
217
- """ Quantizes z, returns quantized zhat, same shape as z. """
218
- quantized = round_ste(self.bound(z))
219
- half_width = self._levels // 2 # Renormalize to [-1, 1].
220
- return quantized / half_width
221
-
222
- def _scale_and_shift(self, zhat_normalized):
223
- half_width = self._levels // 2
224
- return (zhat_normalized * half_width) + half_width
225
-
226
- def _scale_and_shift_inverse(self, zhat):
227
- half_width = self._levels // 2
228
- return (zhat - half_width) / half_width
229
-
230
- def _indices_to_codes(self, indices):
231
- level_indices = self.indices_to_level_indices(indices)
232
- codes = self._scale_and_shift_inverse(level_indices)
233
- return codes
234
-
235
- def codes_to_indices(self, zhat):
236
- """ Converts a `code` to an index in the codebook. """
237
- assert zhat.shape[-1] == self.codebook_dim
238
- zhat = self._scale_and_shift(zhat)
239
- return (zhat * self._basis).sum(dim=-1).to(int32)
240
-
241
- def indices_to_level_indices(self, indices):
242
- """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
243
- indices = rearrange(indices, '... -> ... 1')
244
- codes_non_centered = (indices // self._basis) % self._levels
245
- return codes_non_centered
246
-
247
- def indices_to_codes(self, indices):
248
- """ Inverse of `codes_to_indices`. """
249
- assert exists(indices)
250
-
251
- is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
252
-
253
- codes = self._indices_to_codes(indices)
254
-
255
- if self.keep_num_codebooks_dim:
256
- codes = rearrange(codes, '... c d -> ... (c d)')
257
-
258
- codes = self.project_out(codes)
259
-
260
- if is_img_or_video or self.channel_first:
261
- codes = rearrange(codes, 'b ... d -> b d ...')
262
-
263
- return codes
264
-
265
- # @autocast(device_type='cuda', enabled = False)
266
- def forward(self, z, return_codes=False):
267
- """
268
- einstein notation
269
- b - batch
270
- n - sequence (or flattened spatial dimensions)
271
- d - feature dimension
272
- c - number of codebook dim
273
- """
274
-
275
- is_img_or_video = z.ndim >= 4
276
- need_move_channel_last = is_img_or_video or self.channel_first
277
-
278
- # standardize image or video into (batch, seq, dimension)
279
-
280
- if need_move_channel_last:
281
- z = rearrange(z, 'b d ... -> b ... d')
282
- z, ps = pack_one(z, 'b * d')
283
-
284
- assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
285
-
286
- z = self.project_in(z)
287
-
288
- z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
289
-
290
- # whether to force quantization step to be full precision or not
291
-
292
- force_f32 = self.force_quantization_f32
293
- quantization_context = partial(autocast, device_type='cuda', enabled = False) if force_f32 else nullcontext
294
-
295
- with quantization_context():
296
- orig_dtype = z.dtype
297
-
298
- if force_f32 and orig_dtype not in self.allowed_dtypes:
299
- z = z.float()
300
-
301
- codes = self.quantize(z)
302
-
303
- # returning indices could be optional
304
-
305
- indices = None
306
-
307
- if self.return_indices:
308
- indices = self.codes_to_indices(codes)
309
-
310
- codes = rearrange(codes, 'b n c d -> b n (c d)')
311
-
312
- codes = codes.type(orig_dtype)
313
-
314
- # project out
315
- if return_codes:
316
- return codes, indices
317
-
318
- out = self.project_out(codes)
319
-
320
- # reconstitute image or video dimensions
321
-
322
- if need_move_channel_last:
323
- out = unpack_one(out, ps, 'b * d')
324
- out = rearrange(out, 'b ... d -> b d ...')
325
-
326
- indices = maybe(unpack_one)(indices, ps, 'b * c')
327
-
328
- if not self.keep_num_codebooks_dim and self.return_indices:
329
- indices = maybe(rearrange)(indices, '... 1 -> ...')
330
-
331
- # return quantized output and indices
332
-
333
- return out, indices
 
model.py DELETED
@@ -1,441 +0,0 @@
1
- from typing import Optional, Tuple
2
-
3
- import torch as T
4
- import torch.nn as nn
5
- import torch.nn.functional as F
6
-
7
- from ioblocks import GaussianMixtureIOLayer, FSQ
8
-
9
- from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm
10
- from tokenizer import make_tokenizer
11
-
12
-
13
- from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored
14
- from utils import load_ckpt
15
-
16
-
17
- @si_module
18
- class LatentQuantizer(nn.Module):
19
- class Config:
20
- compressor_config: Optional[FSQ.Config] = None
21
-
22
- dim: Optional[int] = None
23
- ff_dim: Optional[int] = None
24
- input_dim: int = None
25
-
26
- from_pretrained: Optional[Tuple[str, str]] = None
27
-
28
- def __init__(self, c: Config):
29
- super().__init__()
30
-
31
- if exists(c.from_pretrained):
32
- checkpoint = load_ckpt(*c.from_pretrained)
33
- else:
34
- assert exists(c.compressor_config), f'hmm {c}'
35
-
36
- self.compressor = c.compressor_config()
37
- self.ffnn = FFNN(c.dim, c.ff_dim)
38
- self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity()
39
-
40
- if exists(c.from_pretrained):
41
- self.load_state_dict(checkpoint)
42
-
43
- @T.no_grad()
44
- def forward(self, x, return_latent=False, known_latent=None):
45
- """
46
- x: (B, S, D)
47
- """
48
- if exists(known_latent):
49
- return self.compressor.indices_to_codes(known_latent)
50
-
51
- x = self.input(x)
52
- x = self.ffnn(x)
53
- x, tokens = self.compressor(x)
54
-
55
- if return_latent:
56
- return x, tokens
57
- return x
58
-
59
-
60
- @si_module
61
- class TransformerVAE(nn.Module):
62
- class Config:
63
- io_config: Optional[GaussianMixtureIOLayer.Config] = None
64
- stack_config: Optional[Stack.Config] = None
65
- quantizer_config: Optional[LatentQuantizer.Config] = None
66
-
67
- plex_layer: int = None
68
- plex_roll: int = 1
69
- split: bool = True
70
-
71
- from_pretrained: Optional[Tuple[str, str]] = None
72
-
73
- def __init__(self, c: Config):
74
- super().__init__()
75
-
76
- if exists(c.from_pretrained):
77
- checkpoint = load_ckpt(*c.from_pretrained)
78
- else:
79
- assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}'
80
-
81
- self.io = c.io_config()
82
- self.stack = c.stack_config()
83
-
84
- self.plex_layer = c.stack_config.layers//2
85
- self.plex_roll = c.plex_roll
86
- self.plex_dim = c.quantizer_config.dim
87
-
88
- assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}'
89
- self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim)
90
- self.out_norm = Norm(c.stack_config.dim)
91
-
92
- if c.split:
93
- self.io2 = c.io_config()
94
- self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim)
95
-
96
- self.io2.fc_loc = None
97
- self.io2.fc_scale = None
98
- self.io2.fc_weight = None
99
-
100
- kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
101
- head_dim = c.stack_config.dim // c.stack_config.n_head
102
- self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0)
103
- cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim]
104
- self.cache_shape = cache_shape
105
- self.cache = [None] * self.cache_num_layers
106
-
107
- if exists(c.from_pretrained):
108
- result = self.load_state_dict(checkpoint, strict=False)
109
- print0_colored(result, 'yellow')
110
-
111
- self.quantizer = c.quantizer_config().eval()
112
- self.quantizer.requires_grad = False
113
-
114
- @T.no_grad()
115
- def quantize(self, x):
116
- if self.c.split:
117
- x1, x2 = x.chunk(2, dim=-1)
118
- with T.autocast(device_type='cuda', dtype=T.bfloat16):
119
- quantized1 = self.quantizer(x1)
120
- quantized2 = self.quantizer(x2)
121
- return quantized1, quantized2
122
- else:
123
- with T.autocast(device_type='cuda', dtype=T.bfloat16):
124
- return self.quantizer(x)
125
-
126
- @T.no_grad()
127
- def untokenize(self, token_data):
128
- return self.quantizer(None, known_latent=token_data)
129
-
130
- def init_cache(self, bsize, device, dtype, length:int=None):
131
- cache_shape = self.cache_shape.copy()
132
- cache_shape[1] = length or cache_shape[1]
133
- self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
134
-
135
- def deinit_cache(self):
136
- self.cache = [None] * self.cache_num_layers
137
-
138
- @T.no_grad()
139
- def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None):
140
- if self.c.split:
141
- x1, x2 = data.chunk(2, dim=-1)
142
- x = self.io.input(x1) + self.io2.input(x2)
143
- else:
144
- x = self.io.input(data)
145
-
146
- cache_idx = 0
147
- for l, layer in enumerate(self.stack.layers):
148
- if l == self.plex_layer:
149
- if self.c.split:
150
- plex1, plex2 = self.quantize(data)
151
- plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
152
- plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
153
- if exists(next_tokens):
154
- plex1[:, -1:] = self.untokenize(next_tokens[0])
155
- plex2[:, -1:] = self.untokenize(next_tokens[1])
156
- x1 = x + self.plex_projection(plex1)
157
- x2 = x + self.plex_projection2(plex2)
158
- else:
159
- plex = self.quantize(data)
160
- plex = T.roll(plex, -self.c.plex_roll, dims=1)
161
- if exists(next_tokens):
162
- plex[:, -1:] = self.untokenize(next_tokens)
163
- x = x + self.plex_projection(plex)
164
-
165
- if l < self.plex_layer:
166
- x = layer(x, kv=self.cache[l])
167
- else:
168
- if self.c.split:
169
- x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
170
- cache_idx += 1
171
- x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
172
- cache_idx += 1
173
- else:
174
- x = layer(x, kv=self.cache[l])
175
-
176
- with T.autocast(device_type='cuda', dtype=T.bfloat16):
177
- if self.c.split:
178
- x1, x2 = self.out_norm(x1), self.out_norm(x2)
179
- out1, out2 = self.io.output(x1), self.io.output(x2)
180
- else:
181
- x = self.out_norm(x)
182
- out = self.io.output(x)
183
-
184
- if isnt(temps):
185
- if self.c.split:
186
- return out1, out2
187
- else:
188
- return out
189
- else:
190
- if self.c.split:
191
- next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
192
- next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
193
- next_data = T.cat([next_data1, next_data2], dim=-1)
194
- return next_data
195
- else:
196
- next_data = self.io.temp_sample(out, temps)[:, -1:, :]
197
- return next_data
198
-
199
- @si_module
200
- class HertzDevModel(nn.Module):
201
- class Config:
202
- dim: int
203
- vocab_size: int
204
- stack_config: Optional[Stack.Config] = None
205
- latent_size: int = 32
206
-
207
- split: bool = True
208
-
209
- quantizer_config: Optional[LatentQuantizer.Config] = None
210
- resynthesizer_config: Optional[TransformerVAE.Config] = None
211
-
212
- from_pretrained: Optional[Tuple[str, str]] = None
213
-
214
- def __init__(self, c: Config):
215
- super().__init__()
216
-
217
- if exists(c.from_pretrained):
218
- checkpoint = load_ckpt(*c.from_pretrained)
219
- else:
220
- assert (exists(c.stack_config)), f'hmm {c}'
221
-
222
- self.input = nn.Linear(c.latent_size, c.dim)
223
- if self.c.split:
224
- self.input2 = nn.Linear(c.latent_size, c.dim)
225
-
226
- self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta)
227
-
228
- self.layers = nn.ModuleList([
229
- PerfBlock(
230
- dim=c.stack_config.dim,
231
- layer_id=l,
232
- n_head=c.stack_config.n_head,
233
- kv_heads=c.stack_config.kv_heads,
234
- ff_dim=c.stack_config.ff_dim,
235
- eps=c.stack_config.eps,
236
- shape_rotator=self.shape_rotator,
237
- ) for l in range(c.stack_config.layers)
238
- ])
239
-
240
- self.output = GPTOutput(c.dim, c.vocab_size)
241
- if self.c.split:
242
- self.output2 = GPTOutput(c.dim, c.vocab_size)
243
-
244
- self.cache = [None] * c.stack_config.layers
245
- self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
246
- self.head_dim = c.stack_config.dim // c.stack_config.n_head
247
-
248
- if exists(c.from_pretrained):
249
- result = self.load_state_dict(checkpoint, strict=False)
250
- print0_colored(result, 'yellow')
251
-
252
- self.resynthesizer = c.resynthesizer_config().eval()
253
- self.resynthesizer.requires_grad = False
254
-
255
- self.audio_tokenizer = make_tokenizer(device='cpu')
256
- self.audio_cache = None
257
- self.audio_latent_cache = None
258
- self.use_audio_cache = False
259
-
260
- @T.no_grad()
261
- def tokenize(self, audio_data):
262
- orig_audio_shape = audio_data.shape
263
- if exists(self.audio_cache):
264
- audio_data = T.cat([self.audio_cache, audio_data], dim=-1)
265
- self.audio_cache = audio_data[..., -(6*16_000):]
266
- elif self.use_audio_cache:
267
- self.audio_cache = audio_data[..., -(6*16_000):]
268
-
269
- if audio_data.shape[1] == 2:
270
- enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1])
271
- enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2])
272
- return T.cat([enc_ch1, enc_ch2], dim=-1)[:, -(orig_audio_shape[-1]//2000):]
273
- else:
274
- return self.audio_tokenizer.latent_from_data(audio_data)[:, -(orig_audio_shape[-1]//2000):]
275
-
276
- @T.no_grad()
277
- def untokenize(self, token_data):
278
- if exists(self.audio_latent_cache):
279
- token_data = T.cat([self.audio_latent_cache, token_data], dim=1)
280
- self.audio_latent_cache = token_data[:, -(6*8):]
281
- elif self.use_audio_cache:
282
- self.audio_latent_cache = token_data[:, -(6*8):]
283
-
284
- if token_data.shape[-1] == 2*self.c.latent_size:
285
- dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[:, :self.c.latent_size])
286
- dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[:, self.c.latent_size:])
287
- return T.cat([dec_ch1, dec_ch2], dim=1)[..., -(token_data.shape[1]*2000):]
288
- else:
289
- return self.audio_tokenizer.data_from_latent(token_data)[..., -(token_data.shape[1]*2000):]
290
-
291
- def init_cache(self, bsize, device, dtype, length:int=None):
292
- cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim]
293
- self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
294
- self.resynthesizer.init_cache(bsize, device, dtype, length)
295
- self.use_audio_cache = True
296
-
297
- def deinit_cache(self):
298
- self.cache = [None] * len(self.layers)
299
- self.resynthesizer.deinit_cache()
300
- self.audio_cache = None
301
- self.audio_latent_cache = None
302
- self.use_audio_cache = False
303
-
304
- @T.no_grad()
305
- def forward(self, data):
306
- if self.c.split:
307
- x1, x2 = data.chunk(2, dim=-1)
308
- x = self.input(x1) + self.input2(x2)
309
- else:
310
- x = self.input(data)
311
-
312
- for l, layer in enumerate(self.layers):
313
- x = layer(x, kv=self.cache[l])
314
-
315
- if self.c.split:
316
- return self.output(x), self.output2(x)
317
- else:
318
- return self.output(x)
319
-
320
- @T.no_grad()
321
- def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))):
322
- latents_in = self.tokenize(audio_data)
323
- next_latents = self.next_latent(latents_in, temps)
324
- next_model_latent = next_latents[..., self.c.latent_size:]
325
- audio_decoded = self.untokenize(next_model_latent)[..., -2000:]
326
- return audio_decoded
327
-
328
-
329
- @T.no_grad()
330
- def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))):
331
-
332
- if self.c.split:
333
- logits1, logits2 = self.forward(model_input)
334
- next_logits1 = logits1[:, -1]
335
- next_logits2 = logits2[:, -1]
336
- next_token1 = F.softmax(next_logits1 / temps[0], dim=-1).multinomial(1)
337
- next_token2 = F.softmax(next_logits2 / temps[0], dim=-1).multinomial(1)
338
-
339
- next_input = self.resynthesizer(model_input, next_tokens=(next_token1, next_token2), temps=temps[1])
340
- else:
341
- logits = self.forward(model_input)
342
- next_logits = logits[:, -1]
343
- next_token = F.softmax(next_logits / temps[0], dim=-1).multinomial(1)
344
-
345
- next_input = self.resynthesizer(model_input, next_tokens=next_token, temps=temps[1])
346
-
347
- return next_input
348
-
349
-
350
- @T.no_grad()
351
- def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor:
352
- """
353
- only accepts latent-space data.
354
- """
355
- if use_cache:
356
- self.init_cache(data.shape[0], data.device, T.bfloat16)
357
-
358
- next_input = generated = data
359
-
360
- target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len)
361
-
362
- for _ in tqdm0(range(data.shape[1], target_len)):
363
- model_input = next_input if use_cache else generated
364
-
365
- next_input = self.next_latent(model_input, temps)
366
-
367
- generated = T.cat([generated, next_input], dim=1)
368
-
369
- if use_cache:
370
- self.deinit_cache()
371
- return generated
372
-
373
-
374
-
375
- def get_hertz_dev_config(is_split=True):
376
- if is_split:
377
- checkpoints = [('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935'), ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')]
378
- else:
379
- checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')]
380
-
381
- quantizer_config=LatentQuantizer.Config(
382
- from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'),
383
- compressor_config=FSQ.Config(
384
- levels=[8,8,8,8,8],
385
- dim=2048,
386
- num_codebooks=1,
387
- keep_num_codebooks_dim=None,
388
- scale=None,
389
- allowed_dtypes=['float32', 'float64', 'bfloat16'],
390
- channel_first=False,
391
- projection_has_bias=True,
392
- return_indices=True,
393
- force_quantization_f32=True,
394
- use_rms=False
395
- ),
396
- dim=2048,
397
- ff_dim=8192,
398
- input_dim=32
399
- )
400
-
401
- resynthesizer_config=TransformerVAE.Config(
402
- io_config=GaussianMixtureIOLayer.Config(
403
- latent_dim=32,
404
- dim=4096,
405
- num_components=8,
406
- ),
407
- stack_config=Stack.Config(
408
- layers=8,
409
- dim=4096,
410
- seq_len=8192,
411
- n_head=16,
412
- ff_dim=11008,
413
- kv_heads=16,
414
- eps=1e-5,
415
- theta=10_000
416
- ),
417
- quantizer_config=quantizer_config,
418
- plex_layer=None,
419
- plex_roll=1,
420
- split=is_split,
421
- from_pretrained=checkpoints[0],
422
- )
423
-
424
- return HertzDevModel.Config(
425
- dim=4096,
426
- vocab_size=32_768,
427
- stack_config=Stack.Config(
428
- layers=32,
429
- dim=4096,
430
- seq_len=2048,
431
- n_head=32,
432
- ff_dim=None,
433
- kv_heads=None,
434
- eps=1e-5,
435
- theta=10_000,
436
- ),
437
- quantizer_config=quantizer_config,
438
- resynthesizer_config=resynthesizer_config,
439
- split=is_split,
440
- from_pretrained=checkpoints[1],
441
- )
 
 
prompts/bob_duo.wav DELETED
Binary file (554 kB)
 
prompts/bob_mono.wav DELETED
Binary file (560 kB)
 
prompts/countdown_mono.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:5f22399fd8039c043758ef527a588f608394f761395a437ffc02077fa0dca517
- size 1664746
 
 
 
 
prompts/toaskanymore.wav DELETED
@@ -1,3 +0,0 @@
- version https://git-lfs.github.com/spec/v1
- oid sha256:5fa50bbee5e0f5a952f1b3d8afa2c051e00530b1c9040fcf45e7803a89ce6881
- size 1088044
 
 
 
 
requirements.txt DELETED
@@ -1,14 +0,0 @@
- torch
- torchaudio
- einops
- tqdm
- IPython
- numpy
- soundfile
- websockets
- requests
- sounddevice
- matplotlib
- fastapi
- uvicorn
- argparse
 
 
tokenizer.py DELETED
@@ -1,581 +0,0 @@
1
- import math
2
- from dataclasses import dataclass
3
- from typing import Union, Tuple, Literal
4
-
5
- import torch as T
6
- import torch.nn as nn
7
- from torch.nn.utils.parametrizations import weight_norm
8
-
9
- from utils import load_ckpt
10
- from utils.interp import print_colored
11
- from utils import si_module, get_activation
12
-
13
-
14
-
15
- # Adapted from https://github.com/facebookresearch/AudioDec
16
-
17
- def Conv1d1x1(in_channels, out_channels, bias=True):
18
- return nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)
19
-
20
-
21
- class NonCausalConv1d(nn.Module):
22
- """1D noncausal convloution w/ 2-sides padding."""
23
-
24
- def __init__(
25
- self,
26
- in_channels,
27
- out_channels,
28
- kernel_size,
29
- stride=1,
30
- padding=-1,
31
- dilation=1,
32
- groups=1,
33
- bias=True):
34
- super().__init__()
35
- self.in_channels = in_channels
36
- self.out_channels = out_channels
37
- self.kernel_size = kernel_size
38
- if padding < 0:
39
- padding = (kernel_size - 1) // 2 * dilation
40
- self.dilation = dilation
41
- self.conv = nn.Conv1d(
42
- in_channels=in_channels,
43
- out_channels=out_channels,
44
- kernel_size=kernel_size,
45
- stride=stride,
46
- padding=padding,
47
- dilation=dilation,
48
- groups=groups,
49
- bias=bias,
50
- )
51
-
52
- def forward(self, x):
53
- """
54
- Args:
55
- x (Tensor): Float tensor variable with the shape (B, C, T).
56
- Returns:
57
- Tensor: Float tensor variable with the shape (B, C, T).
58
- """
59
- x = self.conv(x)
60
- return x
61
-
62
-
63
- class NonCausalConvTranspose1d(nn.Module):
64
- """1D noncausal transpose convloution."""
65
-
66
- def __init__(
67
- self,
68
- in_channels,
69
- out_channels,
70
- kernel_size,
71
- stride,
72
- padding=-1,
73
- output_padding=-1,
74
- groups=1,
75
- bias=True,
76
- ):
77
- super().__init__()
78
- if padding < 0:
79
- padding = (stride+1) // 2
80
- if output_padding < 0:
81
- output_padding = 1 if stride % 2 else 0
82
- self.deconv = nn.ConvTranspose1d(
83
- in_channels=in_channels,
84
- out_channels=out_channels,
85
- kernel_size=kernel_size,
86
- stride=stride,
87
- padding=padding,
88
- output_padding=output_padding,
89
- groups=groups,
90
- bias=bias,
91
- )
92
-
93
- def forward(self, x):
94
- """
95
- Args:
96
- x (Tensor): Float tensor variable with the shape (B, C, T).
97
- Returns:
98
- Tensor: Float tensor variable with the shape (B, C', T').
99
- """
100
- x = self.deconv(x)
101
- return x
102
-
103
-
104
- class CausalConv1d(NonCausalConv1d):
105
- def __init__(
106
- self,
107
- in_channels,
108
- out_channels,
109
- kernel_size,
110
- stride=1,
111
- dilation=1,
112
- groups=1,
113
- bias=True
114
- ):
115
- super(CausalConv1d, self).__init__(
116
- in_channels=in_channels,
117
- out_channels=out_channels,
118
- kernel_size=kernel_size,
119
- stride=stride,
120
- padding=0,
121
- dilation=dilation,
122
- groups=groups,
123
- bias=bias,
124
- )
125
- self.stride = stride
126
- self.pad_length = (kernel_size - 1) * dilation
127
- def forward(self, x):
128
- pad = nn.ConstantPad1d((self.pad_length, 0), 0.0)
129
- x = pad(x)
130
- return self.conv(x)
131
-
132
-
133
- class CausalConvTranspose1d(NonCausalConvTranspose1d):
134
- def __init__(
135
- self,
136
- in_channels,
137
- out_channels,
138
- kernel_size,
139
- stride,
140
- bias=True,
141
- pad_buffer=None,
142
- ):
143
- super(CausalConvTranspose1d, self).__init__(
144
- in_channels=in_channels,
145
- out_channels=out_channels,
146
- kernel_size=kernel_size,
147
- stride=stride,
148
- padding=0,
149
- output_padding=0,
150
- bias=bias,
151
- )
152
- self.stride = stride
153
- self.pad_length = (math.ceil(kernel_size/stride) - 1)
154
- if pad_buffer is None:
155
- pad_buffer = T.zeros(1, in_channels, self.pad_length)
156
- self.register_buffer("pad_buffer", pad_buffer)
157
-
158
- def forward(self, x):
159
- pad = nn.ReplicationPad1d((self.pad_length, 0))
160
- x = pad(x)
161
- return self.deconv(x)[:, :, self.stride : -self.stride]
162
-
163
- def inference(self, x):
164
- x = T.cat((self.pad_buffer, x), -1)
165
- self.pad_buffer = x[:, :, -self.pad_length:]
166
- return self.deconv(x)[:, :, self.stride : -self.stride]
167
-
168
- def reset_buffer(self):
169
- self.pad_buffer.zero_()
170
-
171
-
172
- class NonCausalResUnit(nn.Module):
173
- def __init__(
174
- self,
175
- in_channels,
176
- out_channels,
177
- kernel_size=7,
178
- dilation=1,
179
- bias=False,
180
- ):
181
- super().__init__()
182
- self.activation = nn.ELU()
183
- self.conv1 = NonCausalConv1d(
184
- in_channels=in_channels,
185
- out_channels=out_channels,
186
- kernel_size=kernel_size,
187
- stride=1,
188
- dilation=dilation,
189
- bias=bias,
190
- )
191
- self.conv2 = Conv1d1x1(out_channels, out_channels, bias)
192
-
193
- def forward(self, x):
194
- y = self.conv1(self.activation(x))
195
- y = self.conv2(self.activation(y))
196
- return x + y
197
-
198
-
199
- class CausalResUnit(NonCausalResUnit):
200
- def __init__(
201
- self,
202
- in_channels,
203
- out_channels,
204
- kernel_size=7,
205
- dilation=1,
206
- bias=False,
207
- ):
208
- super(CausalResUnit, self).__init__(
209
- in_channels=in_channels,
210
- out_channels=out_channels,
211
- kernel_size=kernel_size,
212
- dilation=dilation,
213
- bias=bias,
214
- )
215
- self.conv1 = CausalConv1d(
216
- in_channels=in_channels,
217
- out_channels=out_channels,
218
- kernel_size=kernel_size,
219
- stride=1,
220
- dilation=dilation,
221
- bias=bias,
222
- )
223
-
224
- def inference(self, x):
225
- y = self.conv1.inference(self.activation(x))
226
- y = self.conv2(self.activation(y))
227
- return x + y
228
-
229
-
230
- class ResNetBlock(nn.Module):
231
- def __init__(self,
232
- in_channels,
233
- out_channels,
234
- stride,
235
- kernel_size=7,
236
- dilations=(1, 3, 9),
237
- bias=True,
238
- mode='encoder',
239
- ):
240
- super().__init__()
241
- assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!"
242
-
243
- self.mode = mode
244
- self.stride = stride
245
-
246
- ConvUnit = CausalConv1d if mode == 'encoder' else CausalConvTranspose1d
247
-
248
- res_channels = in_channels if mode == 'encoder' else out_channels
249
-
250
- res_units = [CausalResUnit(
251
- res_channels,
252
- res_channels,
253
- kernel_size=kernel_size,
254
- dilation=dilation,
255
- ) for dilation in dilations]
256
-
257
- if in_channels == out_channels:
258
- if mode == 'encoder':
259
- self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride)
260
- if mode == 'decoder':
261
- self.upsample = nn.Upsample(scale_factor=stride, mode='nearest')
262
- conv_unit = nn.Conv1d(
263
- in_channels=in_channels,
264
- out_channels=out_channels,
265
- kernel_size=1,
266
- bias=bias,
267
- ) if in_channels != out_channels else nn.Identity()
268
- else:
269
- conv_unit = ConvUnit(
270
- in_channels=in_channels,
271
- out_channels=out_channels,
272
- kernel_size=(2 * stride),
273
- stride=stride,
274
- bias=bias,
275
- )
276
-
277
- if mode == 'encoder':
278
- if in_channels == out_channels:
279
- self.res_block = nn.Sequential(*res_units, self.pool, conv_unit)
280
- else:
281
- self.res_block = nn.Sequential(*res_units, conv_unit)
282
- elif mode == 'decoder':
283
- if in_channels == out_channels:
284
- self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units)
285
- else:
286
- self.res_block = nn.Sequential(conv_unit, *res_units)
287
-
288
- def forward(self, x):
289
- out = x
290
- for unit in self.res_block:
291
- out = unit(out)
292
- return out
293
-
294
- def inference(self, x):
295
- for unit in self.res_block:
296
- x = unit.inference(x)
297
- return x
298
-
299
-
300
-
301
-
302
- @si_module
303
- class ResNetStack(nn.Module):
304
- """
305
- ResNet encoder or decoder stack. Channel ratios
306
- and strides take the default order of from
307
- data/io-layer, to the middle of the model.
308
- """
309
- class Config:
310
- input_channels: int = 1
311
- output_channels: int = 1
312
- encode_channels: int = 32
313
- decode_channel_multiplier: int = 1
314
- latent_dim: int = None
315
- kernel_size: int = 7
316
- bias: bool = True
317
- channel_ratios: Tuple[int, ...] = (2, 4, 8, 16)
318
- strides: Tuple[int, ...] = (3, 4, 5, 5)
319
- mode: Literal['encoder', 'decoder'] = 'encoder'
320
-
321
- def __init__(self, c: Config):
322
- super().__init__()
323
- assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!"
324
-
325
- self.mode = c.mode
326
-
327
- assert len(c.channel_ratios) == len(c.strides)
328
- channel_ratios = (1,) + c.channel_ratios
329
- strides = c.strides
330
- self.middle_channels = c.encode_channels * channel_ratios[-1]
331
- if c.mode == 'decoder':
332
- channel_ratios = tuple(reversed(channel_ratios))
333
- strides = tuple(reversed(strides))
334
-
335
- self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1
336
- res_blocks = [ResNetBlock(
337
- c.encode_channels * channel_ratios[s_idx] * self.multiplier,
338
- c.encode_channels * channel_ratios[s_idx+1] * self.multiplier,
339
- stride,
340
- kernel_size=c.kernel_size,
341
- bias=c.bias,
342
- mode=c.mode,
343
- ) for s_idx, stride in enumerate(strides)]
344
-
345
- data_conv = CausalConv1d(
346
- in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier,
347
- out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels,
348
- kernel_size=c.kernel_size,
349
- stride=1,
350
- bias=False,
351
- )
352
-
353
- if c.mode == 'encoder':
354
- self.res_stack = nn.Sequential(data_conv, *res_blocks)
355
- elif c.mode == 'decoder':
356
- self.res_stack = nn.Sequential(*res_blocks, data_conv)
357
-
358
- if c.latent_dim is not None:
359
- self.latent_proj = Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias) if c.mode == 'encoder' else Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias)
360
- if self.multiplier != 1:
361
- self.multiplier_proj = Conv1d1x1(self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias)
362
-
363
- def forward(self, x, return_feats=False):
364
- if self.c.latent_dim is not None and self.mode == 'decoder':
365
- x = self.latent_proj(x)
366
- if self.multiplier != 1:
367
- x = self.multiplier_proj(x)
368
-
369
- feats = []
370
- for block in self.res_stack:
371
- x = block(x)
372
- if return_feats:
373
- feats.append(x)
374
- if self.c.latent_dim is not None and self.mode == 'encoder':
375
- x = self.latent_proj(x)
376
- if return_feats:
377
- feats.append(x)
378
- if return_feats:
379
- return feats
380
- return x
381
-
382
- def inference(self, x):
383
- for block in self.res_stack:
384
- x = block.inference(x)
385
- return x
386
-
387
- def reset_buffer(self):
388
- def _reset_buffer(m):
389
- if isinstance(m, CausalConv1d) or isinstance(m, CausalConvTranspose1d):
390
- m.reset_buffer()
391
- self.apply(_reset_buffer)
392
-
393
- def reset_parameters(self):
394
- def _reset_parameters(m):
395
- if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
396
- m.weight.data.normal_(0.0, 0.01)
397
-
398
- self.apply(_reset_parameters)
399
-
400
-
401
- def apply_weight_norm(self):
402
- def _apply_weight_norm(m):
403
- if isinstance(m, nn.Conv1d) or isinstance(
404
- m, nn.ConvTranspose1d
405
- ):
406
- nn.utils.parametrizations.weight_norm(m)
407
-
408
- self.apply(_apply_weight_norm)
409
-
410
-
411
- def remove_weight_norm(self):
412
- def _remove_weight_norm(m):
413
- try:
414
- print(m)
415
- nn.utils.remove_weight_norm(m)
416
- except ValueError: # this module didn't have weight norm
417
- return
418
-
419
- self.apply(_remove_weight_norm)
420
-
421
-
422
-
423
- @si_module
424
- class GaussianZ(nn.Module):
425
- class Config:
426
- dim: int
427
- latent_dim: int
428
- bias: bool = False
429
- use_weight_norm: bool = False
430
-
431
- def __init__(self, c: Config):
432
- super().__init__()
433
-
434
- self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias)
435
- self.proj_out = nn.Linear(c.latent_dim, c.dim, bias=c.bias)
436
-
437
- if c.use_weight_norm:
438
- self.proj_in = weight_norm(self.proj_in)
439
- self.proj_out = weight_norm(self.proj_out)
440
-
441
- def reparam(self, mu, logvar):
442
- std = T.exp(logvar / 2)
443
- eps = T.randn_like(std)
444
- return mu + eps * std
445
-
446
- def kl_divergence(self, mu, logvar):
447
- return T.mean(-0.5 * T.sum(
448
- 1 + logvar - mu.pow(2) - logvar.exp(),
449
- dim=(1, 2))
450
- )
451
-
452
- def repr_from_latent(self, latent: Union[dict, T.Tensor]):
453
- if isinstance(latent, T.Tensor):
454
- z = latent
455
- else:
456
- z = self.reparam(latent['mu'], latent['logvar'])
457
- l = self.proj_out(z)
458
- return l
459
-
460
- def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]:
461
- mu, logvar = self.proj_in(x).chunk(2, dim=-1)
462
- kl_div = self.kl_divergence(mu, logvar)
463
- z = self.reparam(mu, logvar)
464
- xhat = self.proj_out(z)
465
- latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div}
466
- return xhat, latent
467
-
468
-
469
-
470
- @si_module
471
- class WaveCodec(nn.Module):
472
- class Config:
473
- resnet_config: ResNetStack.Config = None
474
- sample_rate: int = 16_000
475
- use_weight_norm: bool = False
476
-
477
- compressor_config: dataclass = None
478
-
479
- norm_stddev: float = 1.0
480
-
481
- def __init__(self, c: Config):
482
- super().__init__()
483
- self.norm_stddev = c.norm_stddev
484
- self.encoder = c.resnet_config(mode='encoder')
485
- self.sample_rate = c.sample_rate
486
-
487
- self.total_stride = 1
488
- for stride in c.resnet_config.strides:
489
- self.total_stride *= stride
490
- self.tokens_per_second = self.sample_rate / self.total_stride
491
-
492
- self.compressor = c.compressor_config(dim=self.encoder.middle_channels)
493
-
494
- self.decoder = c.resnet_config(mode='decoder')
495
-
496
- if c.use_weight_norm:
497
- self.encoder.apply_weight_norm()
498
- self.decoder.apply_weight_norm()
499
- self.encoder.reset_parameters()
500
- self.decoder.reset_parameters()
501
-
502
- def encode(self, data):
503
- return self.encoder(data/self.norm_stddev)
504
-
505
- def decode(self, latent):
506
- return self.decoder(latent.transpose(1, 2))*self.norm_stddev
507
-
508
- @T.no_grad()
509
- def latent_from_data(self, data, get_parameters=False):
510
- x = self.encode(data)
511
- l_in = x.transpose(1, 2)
512
- l, latent = self.compressor(l_in)
513
- return latent['z'] if not get_parameters else {
514
- 'mu': latent['mu'],
515
- 'logvar': latent['logvar'],
516
- 'z': latent['z'],
517
- }
518
-
519
- @T.no_grad()
520
- def data_from_latent(self, latent):
521
- l = self.compressor.repr_from_latent(latent)
522
- x = self.decode(l)
523
- return x
524
-
525
- def process(self, x):
526
- return self.latent_from_data(x)
527
-
528
- def unprocess(self, latent):
529
- return self.data_from_latent(latent)
530
-
531
- def forward(self, audio_input):
532
- x = self.encode(audio_input)
533
-
534
- l_in = x.transpose(1, 2)
535
- l, latent = self.compressor(l_in)
536
-
537
- xhat = self.decode(l)
538
- return xhat, latent
539
-
540
-
541
-
542
- def make_tokenizer(device='cuda'):
543
- generator_config = WaveCodec.Config(
544
- resnet_config=ResNetStack.Config(
545
- input_channels=1,
546
- output_channels=1,
547
- encode_channels=16,
548
- decode_channel_multiplier=4,
549
- kernel_size=7,
550
- bias=True,
551
- channel_ratios=(4, 8, 16, 16, 16, 16),
552
- strides=(2, 2, 4, 5, 5, 5),
553
- mode=None,
554
- ),
555
- use_weight_norm=True,
556
-
557
- compressor_config=GaussianZ.Config(
558
- dim=None,
559
- latent_dim=32,
560
-
561
- bias=True,
562
- use_weight_norm=True
563
- ),
564
-
565
- norm_stddev=0.05,
566
- )
567
- checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97")
568
-
569
- tokenizer = generator_config()
570
-
571
- load_result = tokenizer.load_state_dict(checkpoint, strict=False)
572
- print_colored(f"Loaded tokenizer state dict: {load_result}", "grey")
573
-
574
- tokenizer = tokenizer.eval()
575
- # Only convert to bfloat16 if using CUDA
576
- if device == 'cuda':
577
- tokenizer = tokenizer.bfloat16()
578
- tokenizer = tokenizer.to(device)
579
- tokenizer.requires_grad_ = False
580
- return tokenizer
581
-
 
 
transformer.py DELETED
@@ -1,382 +0,0 @@
1
- from typing import Optional, Tuple, MutableMapping
2
- from typing import Union
3
- import math
4
- from contextlib import nullcontext
5
-
6
- import torch
7
- import torch as T
8
- import torch.nn as nn
9
- import torch.nn.functional as F
10
- from torch import Tensor
11
- from torch.nn.attention import SDPBackend
12
-
13
- from einops import rearrange
14
-
15
- from utils import si_module, default, exists, load_ckpt
16
-
17
- CACHE_FILL_VALUE = -1
18
-
19
- def get_cache_len(cache: Optional[Tensor]) -> int:
20
- """
21
- cache: (batch, seq_len, 2, kv_heads, head_dim)
22
- """
23
- if cache is None:
24
- return 0
25
- nonzeros = T.any(cache.flatten(2) != CACHE_FILL_VALUE, dim=-1)
26
- length = nonzeros.sum(dim=-1).int()
27
- assert T.all(length == length[0])
28
- return length[0]
29
-
30
-
31
- def rotate_half(x):
32
- x1, x2 = x.chunk(2, dim=-1)
33
- return torch.cat((-x2, x1), dim=-1)
34
-
35
-
36
- def apply_rotary_pos_emb(x, cos, sin, offset: int = 0):
37
- assert (
38
- cos.shape[1] >= offset + x.shape[1]
39
- ), f"Offset and/or input sequence is too large,\
40
- \n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}"
41
-
42
- cos_out = cos[:, offset : offset + x.shape[1], :, :]
43
- sin_out = sin[:, offset : offset + x.shape[1], :, :]
44
-
45
- return (x * cos_out) + (rotate_half(x) * sin_out)
46
-
47
-
48
- # Adapted from https://github.com/foundation-model-stack/foundation-model-stack
49
- class ShapeRotator:
50
- def __init__(
51
- self,
52
- dim: int,
53
- end: int,
54
- theta: float = 10_000,
55
- ):
56
- super().__init__()
57
- self.dim = dim
58
- self.ratio = theta
59
- self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {}
60
- self.max_seq_len_cached: MutableMapping[int, int] = {}
61
- self.ntk_scaling = False
62
- self.max_seq_len = end
63
-
64
- def compute_freqs_cis(self, device, max_seq_len=None):
65
- alpha = 1
66
- dev_idx = device.index
67
- max_seq_len = default(max_seq_len, self.max_seq_len)
68
-
69
- if dev_idx not in self.cached_freqs:
70
- self.cached_freqs[dev_idx] = {}
71
- if dev_idx not in self.max_seq_len_cached:
72
- self.max_seq_len_cached[dev_idx] = 0
73
-
74
-
75
- if self.max_seq_len_cached[dev_idx] > 0:
76
- return 1
77
- max_seq_len = max(max_seq_len, self.max_seq_len)
78
-
79
- if (
80
- 1 in self.cached_freqs[dev_idx]
81
- and max_seq_len <= self.max_seq_len_cached[dev_idx]
82
- ):
83
- return 1
84
-
85
- ratio = self.ratio
86
- dim = self.dim
87
-
88
- freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim))
89
-
90
- t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype)
91
- freqs = torch.einsum("i,j->ij", t, freqs)
92
- emb = torch.cat((freqs, freqs), dim=-1).to(device)
93
-
94
- cos_to_cache = emb.cos()[None, :, None, :]
95
- sin_to_cache = emb.sin()[None, :, None, :]
96
-
97
- self.max_seq_len_cached[dev_idx] = max_seq_len
98
-
99
- self.cached_freqs[dev_idx][alpha] = torch.stack(
100
- [
101
- cos_to_cache,
102
- sin_to_cache,
103
- ],
104
- dim=-1,
105
- )
106
-
107
- return alpha
108
-
109
- def rotate(
110
- self,
111
- q: Tensor,
112
- k: Tensor,
113
- offset: int = 0,
114
- ) -> Tuple[Tensor, Tensor]:
115
- """
116
- Args
117
- ----
118
- q : torch.Tensor
119
- Embedded query tensor, expected size is B x S x H x Eh
120
- k : torch.Tensor
121
- Embedded key tensor, expected size is B x S x H x Eh
122
- """
123
- assert len(q.size()) == 4
124
- assert len(k.size()) == 4
125
-
126
- seq_len = self.max_seq_len
127
- alpha = self.compute_freqs_cis(q.device, seq_len)
128
- freqs = self.cached_freqs[q.device.index][alpha]
129
-
130
- freqs = freqs.float() # 1 L D/2 2 2
131
- q_out = apply_rotary_pos_emb(q, freqs[..., 0], freqs[..., 1], offset=offset).type_as(q)
132
- k_out = apply_rotary_pos_emb(k, freqs[..., 0], freqs[..., 1], offset=offset).type_as(k)
133
-
134
- return q_out.view_as(q), k_out.view_as(k)
135
-
136
- class Linear(nn.Linear):
137
- def __init__(self, *args, **kwargs):
138
- super().__init__(*args, **kwargs, bias=False)
139
-
140
- class Norm(nn.Module):
141
- def __init__(self,
142
- dim: int,
143
- eps: float = 1e-5,) -> None:
144
- super().__init__()
145
- self.eps = eps
146
- self.weight = nn.Parameter(T.ones((dim,)))
147
-
148
- def forward(self, input: Tensor) -> Tensor:
149
- return F.layer_norm(input, (self.weight.shape[0],), weight=self.weight, bias=None, eps=self.eps)
150
-
151
-
152
- class FFNN(nn.Module):
153
- def __init__(self,
154
- dim: int,
155
- expand_dim: int = None,):
156
- super().__init__()
157
- expand_dim = default(expand_dim, 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256))
158
- self.dim = dim
159
- self.expand_dim = expand_dim
160
-
161
- self.gateup_proj = Linear(dim, 2*expand_dim)
162
- self.down_proj = Linear(expand_dim, dim)
163
-
164
- def forward(self, x):
165
- gate, up = self.gateup_proj(x).chunk(2, dim=-1)
166
- return self.down_proj(up * F.silu(gate))
167
-
168
- class GQA(nn.Module):
169
- def __init__(self,
170
- dim: int,
171
- n_head: int,
172
- shape_rotator: ShapeRotator,
173
- kv_heads: Optional[int] = None,
174
- eps: float = 1e-5,
175
- causal: bool = True,):
176
- super().__init__()
177
- self.n_heads = n_head
178
- self.kv_heads = default(kv_heads, n_head)
179
- self.head_dim = dim // n_head
180
- self.causal = causal
181
-
182
- self.proj_qkv = Linear(dim, self.head_dim*(n_head+2*self.kv_heads))
183
-
184
- self.norm_q = Norm(self.head_dim*n_head, eps=eps)
185
- self.norm_k = Norm(self.head_dim*self.kv_heads, eps=eps)
186
-
187
- self.attn_out = Linear(dim, dim)
188
-
189
- self.shape_rotator = shape_rotator
190
-
191
- def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
192
- k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
193
- v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
194
- with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION) if k.device.type == 'cuda' else nullcontext():
195
- x = F.scaled_dot_product_attention(
196
- q.transpose(1, 2),
197
- k.transpose(1, 2),
198
- v.transpose(1, 2),
199
- is_causal=False if (q.size(1) != k.size(1)) else self.causal,
200
- )
201
- x = x.transpose(1, 2).contiguous()
202
- return x
203
-
204
- def _attend(self, q: Tensor, k: Tensor, v: Tensor, kv_cache: Optional[Tensor] = None,):
205
- cache_len = get_cache_len(kv_cache)
206
- q, k = self.shape_rotator.rotate(q, k, offset=cache_len)
207
- if exists(kv_cache):
208
- k = T.cat([kv_cache[:, :cache_len, 0], k], dim=1)
209
- v = T.cat([kv_cache[:, :cache_len, 1], v], dim=1)
210
- kv_cache[:, :k.size(1), 0] = k
211
- kv_cache[:, :v.size(1), 1] = v
212
- x = self._sdpa(q, k, v)
213
- return self.attn_out(rearrange(x, 'b s h d -> b s (h d)'))
214
-
215
- def _project(self, x):
216
- full_q, full_k, full_v = self.proj_qkv(x).chunk(3, dim=-1)
217
- normed_full_q = self.norm_q(full_q).to(full_q.dtype)
218
- normed_full_k = self.norm_k(full_k).to(full_k.dtype)
219
-
220
- q = rearrange(normed_full_q, 'b s (h d) -> b s h d', h=self.n_heads)
221
- k = rearrange(normed_full_k, 'b s (h d) -> b s h d', h=self.kv_heads)
222
- v = rearrange(full_v, 'b s (h d) -> b s h d', h=self.kv_heads)
223
- return q, k, v
224
-
225
- def forward(self,
226
- x: Tensor,
227
- kv: Optional[Tensor] = None,):
228
- """
229
- x: (B, S, D)
230
- kv: (B, S, H, D)
231
- """
232
- q, k, v = self._project(x)
233
- return self._attend(q, k, v, kv_cache=kv)
234
-
235
-
236
- class PreNormAttn(nn.Module):
237
- def __init__(self,
238
- dim: int,
239
- n_head: int,
240
- shape_rotator: ShapeRotator,
241
- kv_heads: Optional[int] = None,
242
- eps: float = 1e-5,
243
- causal: bool = True,):
244
- super().__init__()
245
- self.attn_norm = Norm(dim, eps=eps)
246
- self.attn = GQA(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
247
-
248
- def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
249
- """
250
- x: (B, S, D)
251
- kv: (B, S, H, D)
252
- """
253
- return x + self.attn(self.attn_norm(x), kv)
254
-
255
- class PreNormFFNN(nn.Module):
256
- def __init__(self,
257
- dim: int,
258
- ff_dim: int,
259
- eps: float = 1e-5,):
260
- super().__init__()
261
- self.ffnn_norm = Norm(dim, eps=eps)
262
- self.ffnn = FFNN(dim, ff_dim)
263
-
264
- def forward(self, x: Tensor) -> Tensor:
265
- return x + self.ffnn(self.ffnn_norm(x))
266
-
267
- class Block(nn.Module):
268
- def __init__(self,
269
- dim: int,
270
- layer_id: int = 0,
271
- n_head: int = 16,
272
- kv_heads: Optional[int] = None,
273
- ff_dim: Optional[int] = None,
274
- eps: float = 1e-5,
275
- causal: bool = True,
276
- shape_rotator: ShapeRotator = None):
277
- super().__init__()
278
- self.attn = PreNormAttn(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
279
- self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps)
280
- self.dim = dim
281
- self.layer_id = layer_id
282
- self.head_dim = dim // n_head
283
- self.expand_dim = self.ffnn.ffnn.expand_dim
284
-
285
- self.reset_parameters()
286
-
287
- def reset_parameters(self):
288
- std = 1.0 / math.sqrt(self.dim)
289
- nn.init.trunc_normal_(self.ffnn.ffnn.gateup_proj.weight, std=std, a=-3 * std, b=3 * std)
290
- nn.init.trunc_normal_(self.attn.attn.proj_qkv.weight, std=std, a=-3 * std, b=3 * std)
291
- nn.init.trunc_normal_(self.attn.attn.attn_out.weight, std=std, a=-3 * std, b=3 * std)
292
-
293
- xstd = 1.0 / math.sqrt(self.expand_dim)
294
- nn.init.trunc_normal_(self.ffnn.ffnn.down_proj.weight, std=xstd, a=-3 * xstd, b=3 * xstd)
295
-
296
- def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
297
- """
298
- x: (B, S, D)
299
- kv: (B, S, H, D)
300
- """
301
- h = self.attn(x, kv)
302
- out = self.ffnn(h)
303
- return out
304
-
305
-
306
-
307
- class GPTOutput(nn.Module):
308
- def __init__(self, dim, vocab_size):
309
- super().__init__()
310
- self.dim = dim
311
- self.norm = Norm(dim)
312
- self.output = Linear(dim, vocab_size)
313
-
314
- self.reset_parameters()
315
-
316
- def reset_parameters(self):
317
- std = 1.0 / math.sqrt(self.dim**2)
318
- nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)
319
-
320
- def forward(self, x):
321
- return self.output(self.norm(x))
322
-
323
- @si_module
324
- class Stack(nn.Module):
325
- class Config:
326
- layers: int
327
- dim: int
328
- seq_len: int
329
- n_head: int = 32
330
- ff_dim: int = None
331
- kv_heads: int = None
332
- eps: float = 1e-5
333
- theta: Union[int, float] = 10_000
334
- causal: bool = True
335
-
336
- from_pretrained: Optional[Tuple[str, int]] = None
337
-
338
- def __init__(self, c: Config):
339
- super().__init__()
340
-
341
- from_pretrained = c.from_pretrained
342
- if exists(from_pretrained):
343
- checkpoint = load_ckpt(c.from_pretrained)
344
-
345
- self.shape_rotator = ShapeRotator(c.dim//c.n_head, c.seq_len, theta=c.theta)
346
-
347
- self.layers = nn.ModuleList([
348
- Block(
349
- dim=c.dim,
350
- layer_id=l,
351
- n_head=c.n_head,
352
- kv_heads=c.kv_heads,
353
- ff_dim=c.ff_dim,
354
- eps=c.eps,
355
- causal=c.causal,
356
- shape_rotator=self.shape_rotator,
357
- ) for l in range(c.layers)
358
- ])
359
-
360
- kv_heads = c.kv_heads or c.n_head
361
- head_dim = c.dim // c.n_head
362
- cache_shape = [c.layers, c.seq_len, 2, kv_heads, head_dim]
363
- self.cache_shape = cache_shape
364
- self.cache = [None] * c.layers
365
-
366
- if exists(from_pretrained):
367
- self.load_state_dict(checkpoint)
368
-
369
- def init_cache(self, bsize, device, dtype, length:int=None):
370
- if self.cache_shape is None:
371
- return
372
- cache_shape = self.cache_shape.copy()
373
- cache_shape[1] = length or cache_shape[1]
374
- self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
375
-
376
- def deinit_cache(self):
377
- self.cache = [None] * len(self.cache)
378
-
379
- def forward(self, x: Tensor) -> Tensor:
380
- for l, layer in enumerate(self.layers):
381
- x = layer(x, kv=self.cache[l])
382
- return x
 
 
utils/__init__.py DELETED
@@ -1,3 +0,0 @@
- from .blocks import *
- from .dist import *
- from .interp import *
 
 
 
 
utils/blocks.py DELETED
@@ -1,92 +0,0 @@
- from dataclasses import dataclass
- from typing import TypeVar, Generic, Type, Optional
- from functools import wraps
- import time
- import random
-
- import torch as T
- import torch.nn as nn
-
- # @TODO: remove si_module from codebase
- # we use this in our research codebase to make modules from callable configs
- si_module_TpV = TypeVar('si_module_TpV')
- def si_module(cls: Type[si_module_TpV]) -> Type[si_module_TpV]:
-     if not hasattr(cls, 'Config') or not isinstance(cls.Config, type):
-         class Config:
-             pass
-         cls.Config = Config
-
-     cls.Config = dataclass(cls.Config)
-
-     class ConfigWrapper(cls.Config, Generic[si_module_TpV]):
-         def __call__(self, *args, **kwargs) -> si_module_TpV:
-             if len(kwargs) > 0:
-                 config_dict = {field.name: getattr(self, field.name) for field in self.__dataclass_fields__.values()}
-                 config_dict.update(kwargs)
-                 new_config = type(self)(**config_dict)
-                 return cls(new_config)
-             else:
-                 return cls(self, *args)
-
-     ConfigWrapper.__module__ = cls.__module__
-     ConfigWrapper.__name__ = f"{cls.__name__}Config"
-     ConfigWrapper.__qualname__ = f"{cls.__qualname__}.Config"
-
-     cls.Config = ConfigWrapper
-
-     original_init = cls.__init__
-     def new_init(self, *args, **kwargs):
-         self.c = next((arg for arg in args if isinstance(arg, cls.Config)), None) or next((arg for arg in kwargs.values() if isinstance(arg, cls.Config)), None)
-         original_init(self, *args, **kwargs)
-         self.register_buffer('_device_tracker', T.Tensor(), persistent=False)
-
-     cls.__init__ = new_init
-
-     @property
-     def device(self):
-         return self._device_tracker.device
-
-     @property
-     def dtype(self):
-         return self._device_tracker.dtype
-
-     cls.device = device
-     cls.dtype = dtype
-
-     return cls
-
-
- def get_activation(nonlinear_activation, nonlinear_activation_params={}):
-     if hasattr(nn, nonlinear_activation):
-         return getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
-     else:
-         raise NotImplementedError(f"Activation {nonlinear_activation} not found in torch.nn")
-
-
- def exists(v):
-     return v is not None
-
- def isnt(v):
-     return not exists(v)
-
- def truthyexists(v):
-     return exists(v) and v is not False
-
- def truthyattr(obj, attr):
-     return hasattr(obj, attr) and truthyexists(getattr(obj, attr))
-
- defaultT = TypeVar('defaultT')
-
- def default(*args: Optional[defaultT]) -> Optional[defaultT]:
-     for arg in args:
-         if exists(arg):
-             return arg
-     return None
-
- def maybe(fn):
-     @wraps(fn)
-     def inner(x, *args, **kwargs):
-         if not exists(x):
-             return x
-         return fn(x, *args, **kwargs)
-     return inner
 
 
utils/dist.py DELETED
@@ -1,99 +0,0 @@
- import os
- import torch as T
- import re
- from tqdm import tqdm
- from datetime import timedelta
-
- import requests
- import hashlib
-
- from io import BytesIO
-
- def rank0():
-     rank = os.environ.get('RANK')
-     if rank is None or rank == '0':
-         return True
-     else:
-         return False
-
- def local0():
-     local_rank = os.environ.get('LOCAL_RANK')
-     if local_rank is None or local_rank == '0':
-         return True
-     else:
-         return False
- class tqdm0(tqdm):
-     def __init__(self, *args, **kwargs):
-         total = kwargs.get('total', None)
-         if total is None and len(args) > 0:
-             try:
-                 total = len(args[0])
-             except TypeError:
-                 pass
-         if total is not None:
-             kwargs['miniters'] = max(1, total // 20)
-         super().__init__(*args, **kwargs, disable=not rank0(), bar_format='{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]')
-
- def print0(*args, **kwargs):
-     if rank0():
-         print(*args, **kwargs)
-
- _PRINTED_IDS = set()
-
- def printonce(*args, id=None, **kwargs):
-     if id is None:
-         id = ' '.join(map(str, args))
-
-     if id not in _PRINTED_IDS:
-         print(*args, **kwargs)
-         _PRINTED_IDS.add(id)
-
- def print0once(*args, **kwargs):
-     if rank0():
-         printonce(*args, **kwargs)
-
- def init_dist():
-     if T.distributed.is_initialized():
-         print0('Distributed already initialized')
-         rank = T.distributed.get_rank()
-         local_rank = int(os.environ.get('LOCAL_RANK', 0))
-         world_size = T.distributed.get_world_size()
-     else:
-         try:
-             rank = int(os.environ['RANK'])
-             local_rank = int(os.environ['LOCAL_RANK'])
-             world_size = int(os.environ['WORLD_SIZE'])
-             device = f'cuda:{local_rank}'
-             T.cuda.set_device(device)
-             T.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=30), rank=rank, world_size=world_size, device_id=T.device(device))
-             print(f'Rank {rank} of {world_size}.')
-         except Exception as e:
-             print0once(f'Not initializing distributed env: {e}')
-             rank = 0
-             local_rank = 0
-             world_size = 1
-     return rank, local_rank, world_size
-
- def load_ckpt(load_from_location, expected_hash=None):
-     if local0():
-         os.makedirs('ckpt', exist_ok=True)
-         url = f"https://ckpt.si.inc/hertz-dev/{load_from_location}.pt"
-         save_path = f"ckpt/{load_from_location}.pt"
-         if not os.path.exists(save_path):
-             response = requests.get(url, stream=True)
-             total_size = int(response.headers.get('content-length', 0))
-             with open(save_path, 'wb') as f, tqdm(total=total_size, desc=f'Downloading {load_from_location}.pt', unit='GB', unit_scale=1/(1024*1024*1024)) as pbar:
-                 for chunk in response.iter_content(chunk_size=8192):
-                     f.write(chunk)
-                     pbar.update(len(chunk))
-         if expected_hash is not None:
-             with open(save_path, 'rb') as f:
-                 file_hash = hashlib.md5(f.read()).hexdigest()
-             if file_hash != expected_hash:
-                 print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and trying again.')
-                 os.remove(save_path)
-                 return load_ckpt(load_from_location, expected_hash)
-     if T.distributed.is_initialized():
-         T.distributed.barrier() # so that ranks don't try to load checkpoint before it's finished downloading
-     loaded = T.load(f"ckpt/{load_from_location}.pt", weights_only=False, map_location='cpu')
-     return loaded
 
 
utils/interp.py DELETED
@@ -1,84 +0,0 @@
- import torch as T
- import os
-
- def rank0():
-     rank = os.environ.get('RANK')
-     if rank is None or rank == '0':
-         return True
-     else:
-         return False
-
- def print_colored(message, color='reset', bold=False, **kwargs):
-     color_dict = {
-         'bold': '\033[1m',
-         'green': '\033[92m',
-         'yellow': '\033[93m',
-         'red': '\033[91m',
-         'blue': '\033[94m',
-         'grey': '\033[90m',
-         'white': '\033[97m',
-         'reset': '\033[0m'
-     }
-
-     color_code = color_dict.get(color.lower(), color_dict['reset'])
-     prefix = color_dict['bold'] if bold else ''
-     print(f"{prefix}{color_code}{message}{color_dict['reset']}", **kwargs)
-
- def print0_colored(*args, **kwargs):
-     if rank0():
-         print_colored(*args, **kwargs)
-
- def param_count(module):
-     def count_parameters(model):
-         return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-     total_params = count_parameters(module)
-     output = [f'Total model parameters: {total_params:,}', '---------------------------']
-
-     for name, child in module.named_children():
-         params = count_parameters(child)
-         output.append(f'{name} parameters: {params:,}')
-
-     return '\n'.join(output)
-
- def model_size_estimation(module):
-     def estimate_size(model):
-         param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
-         buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
-         return param_size + buffer_size
-
-     total_size = estimate_size(module)
-     output = [f'Total model size: {total_size / 1024**2:.2f} MB', '---------------------------']
-
-     for name, child in module.named_children():
-         child_size = estimate_size(child)
-         output.append(f'{name} size: {child_size / 1024**2:.2f} MB')
-
-     return '\n'.join(output)
-
- def layer_param_distribution(module):
-     def count_parameters(model):
-         return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-     def get_layer_types(model):
-         layer_types = {}
-         for name, module in model.named_modules():
-             layer_type = module.__class__.__name__
-             params = sum(p.numel() for p in module.parameters(recurse=False) if p.requires_grad)
-             if params > 0:
-                 if layer_type not in layer_types:
-                     layer_types[layer_type] = 0
-                 layer_types[layer_type] += params
-         return layer_types
-
-     total_params = count_parameters(module)
-     layer_types = get_layer_types(module)
-
-     output = [f'Total trainable parameters: {total_params:,}', '---------------------------']
-
-     for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
-         percentage = (count / total_params) * 100
-         output.append(f'{layer_type}: {count:,} ({percentage:.2f}%)')
-
-     return '\n'.join(output)
-