Commit 775c5c6 "empty" · 1 parent: fc3fbd4
Space build status: Build error

Files changed (all deleted):
- .gitattributes              +0 -37
- LICENSE                     +0 -201
- README.md                   +0 -13
- app.py                      +0 -142
- inference.ipynb             +0 -236
- inference_client.py         +0 -161
- inference_server.py         +0 -170
- ioblocks.py                 +0 -333
- model.py                    +0 -441
- prompts/bob_duo.wav         +0 -0
- prompts/bob_mono.wav        +0 -0
- prompts/countdown_mono.wav  +0 -3
- prompts/toaskanymore.wav    +0 -3
- requirements.txt            +0 -14
- tokenizer.py                +0 -581
- transformer.py              +0 -382
- utils/__init__.py           +0 -3
- utils/blocks.py             +0 -92
- utils/dist.py               +0 -99
- utils/interp.py             +0 -84
.gitattributes
DELETED
@@ -1,37 +0,0 @@
-*.7z filter=lfs diff=lfs merge=lfs -text
-*.arrow filter=lfs diff=lfs merge=lfs -text
-*.bin filter=lfs diff=lfs merge=lfs -text
-*.bz2 filter=lfs diff=lfs merge=lfs -text
-*.ckpt filter=lfs diff=lfs merge=lfs -text
-*.ftz filter=lfs diff=lfs merge=lfs -text
-*.gz filter=lfs diff=lfs merge=lfs -text
-*.h5 filter=lfs diff=lfs merge=lfs -text
-*.joblib filter=lfs diff=lfs merge=lfs -text
-*.lfs.* filter=lfs diff=lfs merge=lfs -text
-*.mlmodel filter=lfs diff=lfs merge=lfs -text
-*.model filter=lfs diff=lfs merge=lfs -text
-*.msgpack filter=lfs diff=lfs merge=lfs -text
-*.npy filter=lfs diff=lfs merge=lfs -text
-*.npz filter=lfs diff=lfs merge=lfs -text
-*.onnx filter=lfs diff=lfs merge=lfs -text
-*.ot filter=lfs diff=lfs merge=lfs -text
-*.parquet filter=lfs diff=lfs merge=lfs -text
-*.pb filter=lfs diff=lfs merge=lfs -text
-*.pickle filter=lfs diff=lfs merge=lfs -text
-*.pkl filter=lfs diff=lfs merge=lfs -text
-*.pt filter=lfs diff=lfs merge=lfs -text
-*.pth filter=lfs diff=lfs merge=lfs -text
-*.rar filter=lfs diff=lfs merge=lfs -text
-*.safetensors filter=lfs diff=lfs merge=lfs -text
-saved_model/**/* filter=lfs diff=lfs merge=lfs -text
-*.tar.* filter=lfs diff=lfs merge=lfs -text
-*.tar filter=lfs diff=lfs merge=lfs -text
-*.tflite filter=lfs diff=lfs merge=lfs -text
-*.tgz filter=lfs diff=lfs merge=lfs -text
-*.wasm filter=lfs diff=lfs merge=lfs -text
-*.xz filter=lfs diff=lfs merge=lfs -text
-*.zip filter=lfs diff=lfs merge=lfs -text
-*.zst filter=lfs diff=lfs merge=lfs -text
-*tfevents* filter=lfs diff=lfs merge=lfs -text
-prompts/countdown_mono.wav filter=lfs diff=lfs merge=lfs -text
-prompts/toaskanymore.wav filter=lfs diff=lfs merge=lfs -text
LICENSE
DELETED
@@ -1,201 +0,0 @@
-                                 Apache License
-                           Version 2.0, January 2004
-                        http://www.apache.org/licenses/
-
-   TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
-
-   [Sections 1-9 of the standard Apache License 2.0 text: Definitions; Grant of
-    Copyright License; Grant of Patent License; Redistribution; Submission of
-    Contributions; Trademarks; Disclaimer of Warranty; Limitation of Liability;
-    Accepting Warranty or Additional Liability.]
-
-   END OF TERMS AND CONDITIONS
-
-   APPENDIX: How to apply the Apache License to your work.
-
-   [Standard boilerplate instructions for applying the license notice.]
-
-   Copyright 2024 Standard Intelligence PBC
-
-   Licensed under the Apache License, Version 2.0 (the "License");
-   you may not use this file except in compliance with the License.
-   You may obtain a copy of the License at
-
-       http://www.apache.org/licenses/LICENSE-2.0
-
-   Unless required by applicable law or agreed to in writing, software
-   distributed under the License is distributed on an "AS IS" BASIS,
-   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-   See the License for the specific language governing permissions and
-   limitations under the License.
README.md
DELETED
@@ -1,13 +0,0 @@
----
-title: hertz-dev
-emoji: ⚡
-colorFrom: gray
-colorTo: gray
-sdk: gradio
-sdk_version: 5.4.0
-app_file: app.py
-pinned: false
-short_description: Unofficial demo for hertz-dev
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
DELETED
@@ -1,142 +0,0 @@
-import torch as T
-import torch.nn as nn
-import torch.nn.functional as F
-import torchaudio
-from utils import load_ckpt, print_colored
-from tokenizer import make_tokenizer
-from model import get_hertz_dev_config
-import matplotlib.pyplot as plt
-import spaces
-import gradio as gr
-
-device = 'cuda' if T.cuda.is_available() else 'cpu'
-#T.cuda.set_device(0)
-print(f"Using device: {device}")
-
-audio_tokenizer = make_tokenizer(device)
-
-TWO_SPEAKER = False
-
-model_config = get_hertz_dev_config(is_split=TWO_SPEAKER)
-
-generator = model_config()
-generator = generator.eval().to(T.bfloat16).to(device)
-
-
-
-##############
-# Load audio
-
-def load_and_preprocess_audio(audio_path):
-    gr.Info("Loading and preprocessing audio...")
-    # Load audio file
-    audio_tensor, sr = torchaudio.load(audio_path)
-    gr.Info(f"Loaded audio shape: {audio_tensor.shape}")
-
-    if TWO_SPEAKER:
-        if audio_tensor.shape[0] == 1:
-            gr.Info("Converting mono to stereo...")
-            audio_tensor = audio_tensor.repeat(2, 1)
-            gr.Info(f"Stereo audio shape: {audio_tensor.shape}")
-    else:
-        if audio_tensor.shape[0] == 2:
-            gr.Info("Converting stereo to mono...")
-            audio_tensor = audio_tensor.mean(dim=0).unsqueeze(0)
-            gr.Info(f"Mono audio shape: {audio_tensor.shape}")
-
-    # Resample to 16kHz if needed
-    if sr != 16000:
-        gr.Info(f"Resampling from {sr}Hz to 16000Hz...")
-        resampler = torchaudio.transforms.Resample(orig_freq=sr, new_freq=16000)
-        audio_tensor = resampler(audio_tensor)
-
-    # Clip to 5 minutes if needed
-    max_samples = 16000 * 60 * 5
-    if audio_tensor.shape[1] > max_samples:
-        # gr.Info("Clipping audio to 5 minutes...")
-        raise gr.Erorr("Maximum prompt is 5 minutes")
-        # audio_tensor = audio_tensor[:, :max_samples]
-
-    duration_seconds = audio_tensor.shape[1] / sr
-
-    gr.Info("Audio preprocessing complete!")
-    return audio_tensor.unsqueeze(0), duration_seconds
-
-##############
-# Return audio to gradio
-
-def display_audio(audio_tensor):
-    audio_tensor = audio_tensor.cpu().squeeze()
-    if audio_tensor.ndim == 1:
-        audio_tensor = audio_tensor.unsqueeze(0)
-    audio_tensor = audio_tensor.float()
-
-    # Make a waveform plot
-    # plt.figure(figsize=(4, 1))
-    # plt.plot(audio_tensor.numpy()[0], linewidth=0.5)
-    # plt.axis('off')
-    # plt.show()
-
-    # Make an audio player
-    return (16000, audio_tensor.numpy())
-
-def get_completion(encoded_prompt_audio, prompt_len):
-    prompt_len_seconds = prompt_len / 8
-    gr.Info(f"Prompt length: {prompt_len_seconds:.2f}s")
-    with T.autocast(device_type='cuda', dtype=T.bfloat16):
-        completed_audio_batch = generator.completion(
-            encoded_prompt_audio,
-            temps=(.8, (0.5, 0.1)), # (token_temp, (categorical_temp, gaussian_temp))
-            use_cache=True)
-
-    completed_audio = completed_audio_batch
-    print_colored(f"Decoding completion...", "blue")
-    if TWO_SPEAKER:
-        decoded_completion_ch1 = audio_tokenizer.data_from_latent(completed_audio[:, :, :32].bfloat16())
-        decoded_completion_ch2 = audio_tokenizer.data_from_latent(completed_audio[:, :, 32:].bfloat16())
-        decoded_completion = T.cat([decoded_completion_ch1, decoded_completion_ch2], dim=0)
-    else:
-        decoded_completion = audio_tokenizer.data_from_latent(completed_audio.bfloat16())
-    gr.Info(f"Decoded completion shape: {decoded_completion.shape}")
-
-    gr.Info("Preparing audio for playback...")
-
-    audio_tensor = decoded_completion.cpu().squeeze()
-    if audio_tensor.ndim == 1:
-        audio_tensor = audio_tensor.unsqueeze(0)
-    audio_tensor = audio_tensor.float()
-
-    if audio_tensor.abs().max() > 1:
-        audio_tensor = audio_tensor / audio_tensor.abs().max()
-
-    # return audio_tensor[:, max(prompt_len*2000 - 16000, 0):]
-    return audio_tensor
-
-@spaces.GPU
-def run(audio_path):
-    prompt_audio, prompt_len_seconds = load_and_preprocess_audio(audio_path)
-    prompt_len = prompt_len_seconds * 8
-    gr.Info("Encoding prompt...")
-    with T.autocast(device_type='cuda', dtype=T.bfloat16):
-        if TWO_SPEAKER:
-            encoded_prompt_audio_ch1 = audio_tokenizer.latent_from_data(prompt_audio[:, 0:1].to(device))
-            encoded_prompt_audio_ch2 = audio_tokenizer.latent_from_data(prompt_audio[:, 1:2].to(device))
-            encoded_prompt_audio = T.cat([encoded_prompt_audio_ch1, encoded_prompt_audio_ch2], dim=-1)
-        else:
-            encoded_prompt_audio = audio_tokenizer.latent_from_data(prompt_audio.to(device))
-    gr.Info(f"Encoded prompt shape: {encoded_prompt_audio.shape}")
-    gr.Info("Prompt encoded successfully!")
-    # num_completions = 10
-    completion = get_completion(encoded_prompt_audio, prompt_len)
-    return display_audio(completion)
-
-
-
-with gr.Blocks() as demo:
-    gr.Markdown("# hertz-dev")
-    inp = gr.Audio(label="Input Audio", type="filepath", interactive=True)
-    btn = gr.Button("Continue", variant="primary")
-    out = gr.Audio(label="Output", interactive=False)
-    btn.click(run, inputs=inp, outputs=out)
-
-demo.queue().launch()
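
For reference, while a Space built from the app.py above was running, it could be driven programmatically with the gradio_client package. The snippet below is an editorial sketch, not part of the deleted code; the Space id "owner/hertz-dev-space" is a placeholder and the endpoint name "/run" assumes Gradio's default naming (derived from the run function).

# Hypothetical client-side sketch for the Gradio demo above.
from gradio_client import Client, handle_file

client = Client("owner/hertz-dev-space")       # placeholder Space id
result = client.predict(
    handle_file("prompts/toaskanymore.wav"),   # prompt audio, mono or stereo
    api_name="/run",                           # assumed default endpoint name
)
print(result)                                  # path to the generated continuation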
inference.ipynb
DELETED
@@ -1,236 +0,0 @@
-{
- "cells": [
-  [Nine code cells mirroring app.py for interactive use:
-   1. "%load_ext autoreload" / "%autoreload 2".
-   2. Imports (torch as T, torchaudio, utils.load_ckpt / print_colored,
-      tokenizer.make_tokenizer, model.get_hertz_dev_config, matplotlib,
-      IPython.display.Audio), a note that an "undefined symbol:
-      __nvJitLinkComplete_12_4" error means PyTorch/torchaudio should be
-      reinstalled from the cu121 index, device selection, T.cuda.set_device(0).
-   3. audio_tokenizer = make_tokenizer(device); checkpoints are saved in ckpt/
-      and downloaded automatically if missing.
-   4. TWO_SPEAKER = False, switching between single- and two-speaker checkpoints.
-   5. generator = get_hertz_dev_config(is_split=TWO_SPEAKER)(), .eval(),
-      .to(T.bfloat16).to(device).
-   6. load_and_preprocess_audio / display_audio helpers (mono/stereo handling,
-      resample to 16 kHz, clip to 5 minutes, waveform plot, Audio player); notes
-      the model is prompt-sensitive; loads './prompts/toaskanymore.wav' and sets
-      prompt_len = 3 s * 8 frames/s.
-   7. Prompt encoding with audio_tokenizer.latent_from_data under bfloat16 autocast.
-   8. get_completion (generator.completion with temps=(.8, (0.5, 0.1)),
-      use_cache=True; decode with audio_tokenizer.data_from_latent; normalize and
-      trim the prompt) and a loop generating and playing 10 completions.
-   9.-10. Two empty cells.]
- ],
- "metadata": {
-  "kernelspec": {"display_name": ".venv", "language": "python", "name": "python3"},
-  "language_info": {"codemirror_mode": {"name": "ipython", "version": 3},
-   "file_extension": ".py", "mimetype": "text/x-python", "name": "python",
-   "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.10.12"}
- },
- "nbformat": 4,
- "nbformat_minor": 2
-}
inference_client.py
DELETED
@@ -1,161 +0,0 @@
-# server.py remains the same as before
-
-# Updated client.py
-import asyncio
-import websockets
-import sounddevice as sd
-import numpy as np
-import base64
-import queue
-import argparse
-import requests
-import time
-
-class AudioClient:
-    def __init__(self, server_url="ws://localhost:8000", token_temp=None, categorical_temp=None, gaussian_temp=None):
-        # Convert ws:// to http:// for the base URL
-        self.base_url = server_url.replace("ws://", "http://")
-        self.server_url = f"{server_url}/audio"
-
-        # Set temperatures if provided
-        if any(t is not None for t in [token_temp, categorical_temp, gaussian_temp]):
-            self.set_temperature_and_echo(token_temp, categorical_temp, gaussian_temp)
-
-        # Initialize queues
-        self.audio_queue = queue.Queue()
-        self.output_queue = queue.Queue()
-
-    def set_temperature_and_echo(self, token_temp=None, categorical_temp=None, gaussian_temp=None, echo_testing = False):
-        """Send temperature settings to server"""
-        params = {}
-        if token_temp is not None:
-            params['token_temp'] = token_temp
-        if categorical_temp is not None:
-            params['categorical_temp'] = categorical_temp
-        if gaussian_temp is not None:
-            params['gaussian_temp'] = gaussian_temp
-
-        response = requests.post(f"{self.base_url}/set_temperature", params=params)
-        print(response.json()['message'])
-
-    def audio_callback(self, indata, frames, time, status):
-        """This is called for each audio block"""
-        if status:
-            print(status)
-        # if np.isclose(indata, 0).all():
-        #     raise Exception('Audio input is not working - received all zeros')
-        # Convert float32 to int16 for efficient transmission
-        indata_int16 = (indata.copy() * 32767).astype(np.int16)
-        # indata_int16 = np.zeros_like(indata_int16)
-        self.audio_queue.put(indata_int16)
-
-    def output_stream_callback(self, outdata, frames, time, status):
-        """Callback for output stream to get audio data"""
-        if status:
-            print(status)
-
-        try:
-            data = self.output_queue.get_nowait()
-            data = data.astype(np.float32) / 32767.0
-            if len(data) < len(outdata):
-                outdata[:len(data)] = data
-                outdata[len(data):] = 0
-            else:
-                outdata[:] = data[:len(outdata)]
-        except queue.Empty:
-            outdata.fill(0)
-
-    async def process_audio(self):
-        async with websockets.connect(self.server_url) as ws:
-            while self.running:
-                if not self.audio_queue.empty():
-                    # Get recorded audio
-                    audio_data = self.audio_queue.get()
-                    print(f'Data from microphone:{audio_data.shape, audio_data.dtype, audio_data.min(), audio_data.max()}')
-
-                    # Convert to base64
-                    audio_b64 = base64.b64encode(audio_data.tobytes()).decode('utf-8')
-
-                    # Send to server
-                    time_sent = time.time()
-                    await ws.send(f"data:audio/raw;base64,{audio_b64}")
-
-                    # Receive processed audio
-                    response = await ws.recv()
-                    response = response.split(",")[1]
-                    time_received = time.time()
-                    print(f"Data sent: {audio_b64[:10]}. Data received: {response[:10]}. Received in {(time_received - time_sent) * 1000:.2f} ms")
-                    processed_audio = np.frombuffer(
-                        base64.b64decode(response),
-                        dtype=np.int16
-                    ).reshape(-1, CHANNELS)
-                    print(f'Data from model:{processed_audio.shape, processed_audio.dtype, processed_audio.min(), processed_audio.max()}')
-
-                    self.output_queue.put(processed_audio)
-
-    def start(self):
-        self.running = True
-        # Print audio device information
-        devices = sd.query_devices()
-        default_input = sd.query_devices(kind='input')
-        default_output = sd.query_devices(kind='output')
-
-        print("\nAudio Device Configuration:")
-        print("-" * 50)
-        print(f"Default Input Device:\n{default_input}\n")
-        print(f"Default Output Device:\n{default_output}\n")
-        print("\nAll Available Devices:")
-        print("-" * 50)
-        for i, device in enumerate(devices):
-            print(f"Device {i}:")
-            print(f"Name: {device['name']}")
-            print(f"Channels (in/out): {device['max_input_channels']}/{device['max_output_channels']}")
-            print(f"Sample Rates: {device['default_samplerate']}")
-            print()
-        input_device = input("Enter the index of the input device or press enter for default: ")
-        output_device = input("Enter the index of the output device or press enter for default: ")
-        if input_device == "":
-            input_device = default_input['index']
-        if output_device == "":
-            output_device = default_output['index']
-        with sd.InputStream(callback=self.audio_callback,
-                            channels=CHANNELS,
-                            samplerate=SAMPLE_RATE,
-                            device=int(input_device),
-                            blocksize=2000), \
-             sd.OutputStream(callback=self.output_stream_callback,
-                             channels=CHANNELS,
-                             samplerate=SAMPLE_RATE,
-                             blocksize=2000,
-                             device=int(output_device)):
-
-            asyncio.run(self.process_audio())
-
-    def stop(self):
-        self.running = False
-
-if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Audio Client with Temperature Control')
-    parser.add_argument('--token_temp', '-t1', type=float, help='Token (LM) temperature parameter')
-    parser.add_argument('--categorical_temp', '-t2', type=float, help='Categorical (VAE) temperature parameter')
-    parser.add_argument('--gaussian_temp', '-t3', type=float, help='Gaussian (VAE) temperature parameter')
-    parser.add_argument('--server', '-s', default="ws://localhost:8000",
-                        help='Server URL (default: ws://localhost:8000)')
-
-    args = parser.parse_args()
-
-    # Audio settings
-    SAMPLE_RATE = 16000
-    CHANNELS = 1
-
-    client = AudioClient(
-        server_url=args.server,
-        token_temp=args.token_temp,
-        categorical_temp=args.categorical_temp,
-        gaussian_temp=args.gaussian_temp
-    )
-
-    try:
-        client.start()
-    except KeyboardInterrupt:
-        client.stop()
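
The client and server above exchange raw 16-bit PCM wrapped in "data:audio/raw;base64,..." text messages. The snippet below is a minimal editorial sketch of that framing in isolation, without the sounddevice or websocket plumbing; the random block is made up for illustration.

# Round-trip of the base64/int16 framing used by inference_client.py and inference_server.py.
import base64
import numpy as np

block = np.random.uniform(-1, 1, size=(2000, 1)).astype(np.float32)  # one 125 ms block at 16 kHz

# Client -> server: float32 in [-1, 1] becomes int16 bytes, then base64 text.
int16_block = (block * 32767).astype(np.int16)
message = "data:audio/raw;base64," + base64.b64encode(int16_block.tobytes()).decode("utf-8")

# Server -> client: strip the prefix, decode, and scale back to float32.
payload = message.split(",")[1]
decoded = np.frombuffer(base64.b64decode(payload), dtype=np.int16).reshape(-1, 1)
restored = decoded.astype(np.float32) / 32767.0
assert np.allclose(block, restored, atol=1 / 32767)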
inference_server.py
DELETED
@@ -1,170 +0,0 @@
-import time
-import numpy as np
-from fastapi import FastAPI, WebSocket
-from fastapi.middleware.cors import CORSMiddleware
-import base64
-import uvicorn
-import traceback
-import numpy as np
-import argparse
-
-import torch as T
-import torch.nn.functional as F
-import torchaudio
-
-import os
-from typing import Optional
-
-from utils import print_colored
-from model import get_hertz_dev_config
-
-
-argparse = argparse.ArgumentParser()
-
-argparse.add_argument('--prompt_path', type=str, default='./prompts/bob_mono.wav', help="""
-We highly recommend making your own prompt based on a conversation between you and another person.
-bob_mono.wav seems to work better for two-channel than bob_stereo.wav.
-""")
-args = argparse.parse_args()
-
-
-device = 'cuda' if T.cuda.is_available() else T.device('cpu')
-print_colored(f"Using device: {device}", "grey")
-
-model_config = get_hertz_dev_config(is_split=True)
-
-model = model_config()
-model = model.eval().bfloat16().to(device)
-
-app = FastAPI()
-
-app.add_middleware(
-    CORSMiddleware,
-    allow_origins=["*"],
-    allow_credentials=True,
-    allow_methods=["*"],
-    allow_headers=["*"],
-)
-
-
-# Hyperparams or something.
-SAMPLE_RATE = 16000  # Don't change this
-TEMPS = (0.8, (0.4, 0.1))  # You can change this, but there's also an endpoint for it.
-
-class AudioProcessor:
-    def __init__(self, model, prompt_path):
-        self.model = model
-        self.prompt_path = prompt_path
-        self.initialize_state(prompt_path)
-
-    def initialize_state(self, prompt_path):
-        loaded_audio, sr = torchaudio.load(prompt_path)
-
-        if sr != SAMPLE_RATE:
-            resampler = torchaudio.transforms.Resample(sr, SAMPLE_RATE)
-            loaded_audio = resampler(loaded_audio)
-
-        if loaded_audio.shape[0] == 1:
-            loaded_audio = loaded_audio.repeat(2, 1)
-
-        audio_length = loaded_audio.shape[-1]
-        num_chunks = audio_length // 2000
-        loaded_audio = loaded_audio[..., :num_chunks * 2000]
-
-        self.loaded_audio = loaded_audio.to(device)
-
-        with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode():
-            self.model.init_cache(bsize=1, device=device, dtype=T.bfloat16, length=1024)
-            self.next_model_audio = self.model.next_audio_from_audio(self.loaded_audio.unsqueeze(0), temps=TEMPS)
-        self.prompt_buffer = None
-        self.prompt_position = 0
-        self.chunks_until_live = 24
-        self.initialize_prompt_buffer()
-        print_colored("AudioProcessor state initialized", "green")
-
-    def initialize_prompt_buffer(self):
-        self.recorded_audio = self.loaded_audio
-        prompt_audio = self.loaded_audio.reshape(1, 2, -1)
-        prompt_audio = prompt_audio[:, :, -48000:].cpu().numpy()
-        prompt_audio_mono = prompt_audio.mean(axis=1)
-        self.prompt_buffer = np.array_split(prompt_audio_mono[0], 24)
-        print_colored(f"Initialized prompt buffer with {len(self.prompt_buffer)} chunks", "grey")
-
-    async def process_audio(self, audio_data):
-        if self.chunks_until_live > 0:
-            print_colored(f"Serving from prompt buffer, {self.chunks_until_live} chunks left", "grey")
-            chunk = self.prompt_buffer[24 - self.chunks_until_live]
-            self.chunks_until_live -= 1
-
-            if self.chunks_until_live == 0:
-                print_colored("Switching to live processing mode", "green")
-
-            time.sleep(0.05)
-            return chunk
-
-        audio_tensor = T.from_numpy(audio_data).to(device)
-        audio_tensor = audio_tensor.reshape(1, 1, -1)
-        audio_tensor = T.cat([audio_tensor, self.next_model_audio], dim=1)
-
-        with T.autocast(device_type=device, dtype=T.bfloat16), T.inference_mode():
-            curr_model_audio = self.model.next_audio_from_audio(
-                audio_tensor,
-                temps=TEMPS
-            )
-        print(f"Recorded audio shape {self.recorded_audio.shape}, audio tensor shape {audio_tensor.shape}")
-        self.recorded_audio = T.cat([self.recorded_audio.cpu(), audio_tensor.squeeze(0).cpu()], dim=-1)
-
-        self.next_model_audio = curr_model_audio
-
-        return curr_model_audio.float().cpu().numpy()
-
-    def cleanup(self):
-        print_colored("Cleaning up audio processor...", "blue")
-        os.makedirs('audio_recordings', exist_ok=True)
-        torchaudio.save(f'audio_recordings/{time.strftime("%d-%H-%M")}.wav', self.recorded_audio.cpu(), SAMPLE_RATE)
-        self.model.deinit_cache()
-        self.initialize_state(self.prompt_path)
-        print_colored("Audio processor cleanup complete", "green")
-
-@app.post("/set_temperature")
-async def set_temperature(token_temp: Optional[float] = None, categorical_temp: Optional[float] = None, gaussian_temp: Optional[float] = None):
-    try:
-        global TEMPS
-        TEMPS = (token_temp, (categorical_temp, gaussian_temp))
-
-        print_colored(f"Temperature updated to: {TEMPS}", "green")
-        return {"message": f"Temperature updated to: {TEMPS}", "status": "success"}
-    except Exception as e:
-        print_colored(f"Error setting temperature: {str(e)}", "red")
-        return {"message": f"Error setting temperature: {str(e)}", "status": "error"}
-
-@app.websocket("/audio")
-async def websocket_endpoint(websocket: WebSocket):
-    await websocket.accept()
-    try:
-        while True:
-            data = await websocket.receive_text()
-            audio_data = np.frombuffer(
-                base64.b64decode(data.split(",")[1]),
-                dtype=np.int16
-            )
-            audio_data = audio_data.astype(np.float32) / 32767.0
-            processed_audio = await audio_processor.process_audio(audio_data)
-            processed_audio = (processed_audio * 32767).astype(np.int16)
-
-            processed_data = base64.b64encode(processed_audio.tobytes()).decode('utf-8')
-            await websocket.send_text(f"data:audio/raw;base64,{processed_data}")
-
-    except Exception as e:
-        print_colored(f"WebSocket error: {e}", "red")
-        print_colored(f"Full traceback:\n{traceback.format_exc()}", "red")
-    finally:
-        audio_processor.cleanup()
-        await websocket.close()
-
-
-audio_processor = AudioProcessor(model=model, prompt_path=args.prompt_path)
-
-if __name__ == "__main__":
-    uvicorn.run(app, host="0.0.0.0", port=8000)
-    print("Server started")
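
The server's streaming constants are easier to follow with the arithmetic written out. This is an editorial sketch of the numbers implied by the code above (2000-sample blocks, the 48000-sample prompt tail, and the 24-chunk warm-up buffer), not part of the deleted file.

# Timing implied by inference_server.py / inference_client.py.
SAMPLE_RATE = 16000          # Hz, fixed by the model
BLOCK_SIZE = 2000            # samples per sounddevice block and per model hop

block_ms = 1000 * BLOCK_SIZE / SAMPLE_RATE                  # 125.0 ms of audio per round trip
prompt_tail_s = 48000 / SAMPLE_RATE                         # 3.0 s of prompt kept in the buffer
warmup_chunks = 24                                          # chunks served before going "live"
warmup_s = warmup_chunks * BLOCK_SIZE / SAMPLE_RATE         # 3.0 s, matching the prompt tail

print(block_ms, prompt_tail_s, warmup_s)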
170 |
-
print("Server started")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
ioblocks.py
DELETED
@@ -1,333 +0,0 @@
-from __future__ import annotations
-from functools import partial
-from contextlib import nullcontext
-from typing import List, Tuple
-from math import ceil
-
-import torch as T
-import torch.nn as nn
-import torch.nn.functional as F
-import torch.distributed as dist
-from torch import Tensor, int32
-from torch.amp import autocast
-
-from einops import rearrange, pack, unpack
-
-
-from utils import si_module, exists, default, maybe
-
-
-@si_module
-class GaussianMixtureIOLayer(nn.Module):
-    class Config:
-        latent_dim: int
-        dim: int
-        num_components: int
-
-    def __init__(self, c: Config):
-        super().__init__()
-        self.latent_dim = c.latent_dim
-        self.num_components = c.num_components
-        self.input_projection = nn.Linear(c.latent_dim, c.dim)
-
-        self.fc_loc = nn.Linear(c.dim, c.num_components * c.latent_dim)
-        self.fc_scale = nn.Linear(c.dim, c.num_components * c.latent_dim)
-        self.fc_weight = nn.Linear(c.dim, c.num_components)
-
-    def _square_plus(self, x):
-        return (x + T.sqrt(T.square(x) + 4)) / 2
-
-    def input(self, sampled_latents: T.Tensor) -> T.Tensor:
-        """Pre-sampled latents T.Tensor (B, L, Z) -> float tensor (B, L, D)"""
-        hidden = self.input_projection(sampled_latents)
-        return hidden
-
-    def output(self, h: T.Tensor) -> Tuple[T.Tensor, T.Tensor, T.Tensor]:
-        """float tensor (B, L, D) -> Tuple of locs, scales, and weights"""
-        batch_size, seq_len, _ = h.shape
-
-        locs = self.fc_loc(h).view(batch_size, seq_len, self.num_components, self.latent_dim)
-        scales = T.clamp(self._square_plus(self.fc_scale(h)), min=1e-6).view(batch_size, seq_len, self.num_components, self.latent_dim)
-        weights = self.fc_weight(h).view(batch_size, seq_len, self.num_components)
-
-        return (locs, scales, weights)
-
-    def loss(self, data, dataHat):
-        locs, scales, weights = dataHat
-        log_probs = -0.5 * T.sum(
-            (data.unsqueeze(-2) - locs).pow(2) / scales.pow(2) +
-            2 * T.log(scales) +
-            T.log(T.tensor(2 * T.pi)),
-            dim=-1
-        )
-        log_weights = F.log_softmax(weights, dim=-1)
-        return -T.logsumexp(log_weights + log_probs, dim=-1)
-
-
-    def temp_sample(self, orig_pdist, temp):
-        locs, scales, weights = orig_pdist
-        if temp is None:
-            component_samples = locs + scales * T.randn_like(scales)
-            mixture_samples = F.gumbel_softmax(weights, hard=True)
-            sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
-        elif isinstance(temp, tuple):
-            assert len(temp) == 2
-            categorical_temp, gaussian_temp = temp
-            component_samples = locs + scales * gaussian_temp * T.randn_like(scales)
-            mixture_samples = F.gumbel_softmax(weights / categorical_temp, hard=True)
-            sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
-        else:
-            component_samples = locs + scales * temp * T.randn_like(scales)
-            mixture_samples = F.gumbel_softmax(weights / temp, hard=True)
-            sampled = (component_samples * mixture_samples.unsqueeze(-1)).sum(dim=-2)
-        return sampled
-
-
-class GPTOutput(nn.Module):
-    def __init__(self, dim, vocab_size):
-        super().__init__()
-        self.output = nn.Linear(dim, vocab_size, bias=False)
-
-    def forward(self, x):
-        return self.output(x)
-
-
-# helper functions
-
-def pack_one(t, pattern):
-    return pack([t], pattern)
-
-def unpack_one(t, ps, pattern):
-    return unpack(t, ps, pattern)[0]
-
-def first(l):
-    return l[0]
-
-def round_up_multiple(num, mult):
-    return ceil(num / mult) * mult
-
-def get_code_utilization(codes, codebook_size, get_global=False):
-    if get_global and dist.is_initialized():
-        world_size = dist.get_world_size()
-    else:
-        world_size = 1
-
-    if world_size > 1:
-        gathered_tokens = [T.zeros_like(codes) for _ in range(world_size)]
-        dist.all_gather(gathered_tokens, codes)
-        gathered_tokens = T.cat(gathered_tokens, dim=0)
-    else:
-        gathered_tokens = codes
-    unique_tokens = len(T.unique(gathered_tokens))
-    code_utilization = unique_tokens / min(gathered_tokens.numel(), codebook_size)
-    return code_utilization
-
-# tensor helpers
-
-def round_ste(z: Tensor) -> Tensor:
-    """Round with straight through gradients."""
-    zhat = z.round()
-    return z + (zhat - z).detach()
-
-# main class
-# lucidrains fsq
-@si_module
-class FSQ(nn.Module):
-    @property
-    def needs_float32_params(self):
-        return True
-
-    class Config:
-        levels: List[int]
-        dim: int | None = None
-        num_codebooks: int = 1
-        keep_num_codebooks_dim: bool | None = None
-        scale: float | None = None
-        allowed_dtypes: Tuple[str, ...] = ('float32', 'float64')
-        channel_first: bool = False
-        projection_has_bias: bool = True
-        return_indices: bool = True
-        force_quantization_f32: bool = True
-        use_rms: bool = False
-
-    def __init__(self, c: Config):
-        super().__init__()
-        _levels = T.tensor(c.levels, dtype=int32)
-        self.register_buffer("_levels", _levels, persistent = False)
-
-        _basis = T.cumprod(T.tensor([1] + c.levels[:-1]), dim=0, dtype=int32)
-        self.register_buffer("_basis", _basis, persistent = False)
-
-        self.scale = c.scale
-
-        codebook_dim = len(c.levels)
-        self.codebook_dim = codebook_dim
-
-        effective_codebook_dim = codebook_dim * c.num_codebooks
-        self.num_codebooks = c.num_codebooks
-
-        self.allowed_dtypes = []
-        for dtype_str in c.allowed_dtypes:
-            if hasattr(T, dtype_str):
-                self.allowed_dtypes.append(getattr(T, dtype_str))
-            else:
-                raise ValueError(f"Invalid dtype string: {dtype_str}")
-
-        self.effective_codebook_dim = effective_codebook_dim
-
-        keep_num_codebooks_dim = default(c.keep_num_codebooks_dim, c.num_codebooks > 1)
-        assert not (c.num_codebooks > 1 and not keep_num_codebooks_dim)
-        self.keep_num_codebooks_dim = keep_num_codebooks_dim
-
-        self.dim = default(c.dim, len(_levels) * c.num_codebooks)
-
-        self.channel_first = c.channel_first
-
-        has_projections = self.dim != effective_codebook_dim
-        self.project_in = nn.Linear(self.dim, effective_codebook_dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
-        self.project_out = nn.Linear(effective_codebook_dim, self.dim, bias = c.projection_has_bias) if has_projections else nn.Identity()
-
-        self.has_projections = has_projections
-
-        self.return_indices = c.return_indices
-        if c.return_indices:
-            self.codebook_size = self._levels.prod().item()
-            implicit_codebook = self._indices_to_codes(T.arange(self.codebook_size))
-            self.register_buffer("implicit_codebook", implicit_codebook, persistent = False)
-
-        self.allowed_dtypes = c.allowed_dtypes
-        self.force_quantization_f32 = c.force_quantization_f32
-
-        self.latent_loss = None
-
-    def latent_metric(self, codes, get_global=False):
-        return {'code_util_estimate': get_code_utilization(codes, self.codebook_size, get_global)}
-
-    def repr_from_latent(self, latent):
-        return self.indices_to_codes(latent)
-
-    def bound(self, z, eps: float = 1e-3):
-        """ Bound `z`, an array of shape (..., d). """
-        half_l = (self._levels - 1) * (1 + eps) / 2
-        offset = T.where(self._levels % 2 == 0, 0.5, 0.0)
-        shift = (offset / half_l).atanh()
-        return (z + shift).tanh() * half_l - offset
-
-    def quantize(self, z):
-        """ Quantizes z, returns quantized zhat, same shape as z. """
-        quantized = round_ste(self.bound(z))
-        half_width = self._levels // 2  # Renormalize to [-1, 1].
-        return quantized / half_width
-
-    def _scale_and_shift(self, zhat_normalized):
-        half_width = self._levels // 2
-        return (zhat_normalized * half_width) + half_width
-
-    def _scale_and_shift_inverse(self, zhat):
-        half_width = self._levels // 2
-        return (zhat - half_width) / half_width
-
-    def _indices_to_codes(self, indices):
-        level_indices = self.indices_to_level_indices(indices)
-        codes = self._scale_and_shift_inverse(level_indices)
-        return codes
-
-    def codes_to_indices(self, zhat):
-        """ Converts a `code` to an index in the codebook. """
-        assert zhat.shape[-1] == self.codebook_dim
-        zhat = self._scale_and_shift(zhat)
-        return (zhat * self._basis).sum(dim=-1).to(int32)
-
-    def indices_to_level_indices(self, indices):
-        """ Converts indices to indices at each level, perhaps needed for a transformer with factorized embeddings """
-        indices = rearrange(indices, '... -> ... 1')
-        codes_non_centered = (indices // self._basis) % self._levels
-        return codes_non_centered
-
-    def indices_to_codes(self, indices):
-        """ Inverse of `codes_to_indices`. """
-        assert exists(indices)
-
-        is_img_or_video = indices.ndim >= (3 + int(self.keep_num_codebooks_dim))
-
-        codes = self._indices_to_codes(indices)
-
-        if self.keep_num_codebooks_dim:
-            codes = rearrange(codes, '... c d -> ... (c d)')
-
-        codes = self.project_out(codes)
-
-        if is_img_or_video or self.channel_first:
-            codes = rearrange(codes, 'b ... d -> b d ...')
-
-        return codes
-
-    # @autocast(device_type='cuda', enabled = False)
-    def forward(self, z, return_codes=False):
-        """
-        einstein notation
-        b - batch
-        n - sequence (or flattened spatial dimensions)
-        d - feature dimension
-        c - number of codebook dim
-        """
-
-        is_img_or_video = z.ndim >= 4
-        need_move_channel_last = is_img_or_video or self.channel_first
-
-        # standardize image or video into (batch, seq, dimension)
-
-        if need_move_channel_last:
-            z = rearrange(z, 'b d ... -> b ... d')
-            z, ps = pack_one(z, 'b * d')
-
-        assert z.shape[-1] == self.dim, f'expected dimension of {self.dim} but found dimension of {z.shape[-1]}'
-
-        z = self.project_in(z)
-
-        z = rearrange(z, 'b n (c d) -> b n c d', c = self.num_codebooks)
-
-        # whether to force quantization step to be full precision or not
-
-        force_f32 = self.force_quantization_f32
-        quantization_context = partial(autocast, device_type='cuda', enabled = False) if force_f32 else nullcontext
-
-        with quantization_context():
-            orig_dtype = z.dtype
-
-            if force_f32 and orig_dtype not in self.allowed_dtypes:
-                z = z.float()
-
-            codes = self.quantize(z)
-
-            # returning indices could be optional
-
-
indices = None
|
306 |
-
|
307 |
-
if self.return_indices:
|
308 |
-
indices = self.codes_to_indices(codes)
|
309 |
-
|
310 |
-
codes = rearrange(codes, 'b n c d -> b n (c d)')
|
311 |
-
|
312 |
-
codes = codes.type(orig_dtype)
|
313 |
-
|
314 |
-
# project out
|
315 |
-
if return_codes:
|
316 |
-
return codes, indices
|
317 |
-
|
318 |
-
out = self.project_out(codes)
|
319 |
-
|
320 |
-
# reconstitute image or video dimensions
|
321 |
-
|
322 |
-
if need_move_channel_last:
|
323 |
-
out = unpack_one(out, ps, 'b * d')
|
324 |
-
out = rearrange(out, 'b ... d -> b d ...')
|
325 |
-
|
326 |
-
indices = maybe(unpack_one)(indices, ps, 'b * c')
|
327 |
-
|
328 |
-
if not self.keep_num_codebooks_dim and self.return_indices:
|
329 |
-
indices = maybe(rearrange)(indices, '... 1 -> ...')
|
330 |
-
|
331 |
-
# return quantized output and indices
|
332 |
-
|
333 |
-
return out, indices
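
For reference, a minimal usage sketch of the FSQ layer removed above (editorial illustration, not code from this commit). It assumes ioblocks.py is importable and follows the repo's convention that a Config instance is callable and returns the built module:

# Hypothetical sketch: quantize a latent with FSQ and round-trip the indices.
import torch
from ioblocks import FSQ

fsq = FSQ.Config(levels=[8, 8, 8, 8, 8], dim=2048)()   # 8^5 = 32768 implicit codes
latent = torch.randn(1, 10, 2048)                       # (batch, seq, dim)
quantized, indices = fsq(latent)                        # codes snapped to the FSQ grid + integer ids
assert quantized.shape == latent.shape
recovered = fsq.indices_to_codes(indices)               # decode ids back to projected codes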
model.py
DELETED
@@ -1,441 +0,0 @@
from typing import Optional, Tuple

import torch as T
import torch.nn as nn
import torch.nn.functional as F

from ioblocks import GaussianMixtureIOLayer, FSQ

from transformer import Stack, ShapeRotator, Block as PerfBlock, GPTOutput, CACHE_FILL_VALUE, FFNN, Norm
from tokenizer import make_tokenizer

from utils import si_module, exists, isnt, tqdm0, print0, default, print0_colored
from utils import load_ckpt


@si_module
class LatentQuantizer(nn.Module):
    class Config:
        compressor_config: Optional[FSQ.Config] = None

        dim: Optional[int] = None
        ff_dim: Optional[int] = None
        input_dim: int = None

        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()

        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert exists(c.compressor_config), f'hmm {c}'

        self.compressor = c.compressor_config()
        self.ffnn = FFNN(c.dim, c.ff_dim)
        self.input = nn.Linear(c.input_dim, c.dim) if exists(c.input_dim) else nn.Identity()

        if exists(c.from_pretrained):
            self.load_state_dict(checkpoint)

    @T.no_grad()
    def forward(self, x, return_latent=False, known_latent=None):
        """
        x: (B, S, D)
        """
        if exists(known_latent):
            return self.compressor.indices_to_codes(known_latent)

        x = self.input(x)
        x = self.ffnn(x)
        x, tokens = self.compressor(x)

        if return_latent:
            return x, tokens
        return x


@si_module
class TransformerVAE(nn.Module):
    class Config:
        io_config: Optional[GaussianMixtureIOLayer.Config] = None
        stack_config: Optional[Stack.Config] = None
        quantizer_config: Optional[LatentQuantizer.Config] = None

        plex_layer: int = None
        plex_roll: int = 1
        split: bool = True

        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()

        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert (exists(c.io_config) and exists(c.stack_config) and exists(c.quantizer_config)), f'hmm {c}'

        self.io = c.io_config()
        self.stack = c.stack_config()

        self.plex_layer = c.stack_config.layers//2
        self.plex_roll = c.plex_roll
        self.plex_dim = c.quantizer_config.dim

        assert self.plex_dim is not None and c.stack_config.dim is not None, f'One of the following are None: self.plex_dim: {self.plex_dim}, c.stack_config.dim: {c.stack_config.dim}'
        self.plex_projection = nn.Linear(self.plex_dim, c.stack_config.dim)
        self.out_norm = Norm(c.stack_config.dim)

        if c.split:
            self.io2 = c.io_config()
            self.plex_projection2 = nn.Linear(self.plex_dim, c.stack_config.dim)

            self.io2.fc_loc = None
            self.io2.fc_scale = None
            self.io2.fc_weight = None

        kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
        head_dim = c.stack_config.dim // c.stack_config.n_head
        self.cache_num_layers = c.stack_config.layers + ((c.stack_config.layers - self.plex_layer) if c.split else 0)
        cache_shape = [self.cache_num_layers, c.stack_config.seq_len, 2, kv_heads, head_dim]
        self.cache_shape = cache_shape
        self.cache = [None] * self.cache_num_layers

        if exists(c.from_pretrained):
            result = self.load_state_dict(checkpoint, strict=False)
            print0_colored(result, 'yellow')

        self.quantizer = c.quantizer_config().eval()
        self.quantizer.requires_grad = False

    @T.no_grad()
    def quantize(self, x):
        if self.c.split:
            x1, x2 = x.chunk(2, dim=-1)
            with T.autocast(device_type='cuda', dtype=T.bfloat16):
                quantized1 = self.quantizer(x1)
                quantized2 = self.quantizer(x2)
            return quantized1, quantized2
        else:
            with T.autocast(device_type='cuda', dtype=T.bfloat16):
                return self.quantizer(x)

    @T.no_grad()
    def untokenize(self, token_data):
        return self.quantizer(None, known_latent=token_data)

    def init_cache(self, bsize, device, dtype, length: int = None):
        cache_shape = self.cache_shape.copy()
        cache_shape[1] = length or cache_shape[1]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)

    def deinit_cache(self):
        self.cache = [None] * self.cache_num_layers

    @T.no_grad()
    def forward(self, data, next_tokens: Optional[Tuple[T.Tensor, T.Tensor]] = None, temps: Optional[Tuple[float, Tuple[float, float]]] = None):
        if self.c.split:
            x1, x2 = data.chunk(2, dim=-1)
            x = self.io.input(x1) + self.io2.input(x2)
        else:
            x = self.io.input(data)

        cache_idx = 0
        for l, layer in enumerate(self.stack.layers):
            if l == self.plex_layer:
                if self.c.split:
                    plex1, plex2 = self.quantize(data)
                    plex1 = T.roll(plex1, -self.c.plex_roll, dims=1)
                    plex2 = T.roll(plex2, -self.c.plex_roll, dims=1)
                    if exists(next_tokens):
                        plex1[:, -1:] = self.untokenize(next_tokens[0])
                        plex2[:, -1:] = self.untokenize(next_tokens[1])
                    x1 = x + self.plex_projection(plex1)
                    x2 = x + self.plex_projection2(plex2)
                else:
                    plex = self.quantize(data)
                    plex = T.roll(plex, -self.c.plex_roll, dims=1)
                    if exists(next_tokens):
                        plex[:, -1:] = self.untokenize(next_tokens)
                    x = x + self.plex_projection(plex)

            if l < self.plex_layer:
                x = layer(x, kv=self.cache[l])
            else:
                if self.c.split:
                    x1 = layer(x1, kv=self.cache[self.plex_layer + cache_idx])
                    cache_idx += 1
                    x2 = layer(x2, kv=self.cache[self.plex_layer + cache_idx])
                    cache_idx += 1
                else:
                    x = layer(x, kv=self.cache[l])

        with T.autocast(device_type='cuda', dtype=T.bfloat16):
            if self.c.split:
                x1, x2 = self.out_norm(x1), self.out_norm(x2)
                out1, out2 = self.io.output(x1), self.io.output(x2)
            else:
                x = self.out_norm(x)
                out = self.io.output(x)

        if isnt(temps):
            if self.c.split:
                return out1, out2
            else:
                return out
        else:
            if self.c.split:
                next_data1 = self.io.temp_sample(out1, temps)[:, -1:, :]
                next_data2 = self.io2.temp_sample(out2, temps)[:, -1:, :]
                next_data = T.cat([next_data1, next_data2], dim=-1)
                return next_data
            else:
                next_data = self.io.temp_sample(out, temps)[:, -1:, :]
                return next_data


@si_module
class HertzDevModel(nn.Module):
    class Config:
        dim: int
        vocab_size: int
        stack_config: Optional[Stack.Config] = None
        latent_size: int = 32

        split: bool = True

        quantizer_config: Optional[LatentQuantizer.Config] = None
        resynthesizer_config: Optional[TransformerVAE.Config] = None

        from_pretrained: Optional[Tuple[str, str]] = None

    def __init__(self, c: Config):
        super().__init__()

        if exists(c.from_pretrained):
            checkpoint = load_ckpt(*c.from_pretrained)
        else:
            assert (exists(c.stack_config)), f'hmm {c}'

        self.input = nn.Linear(c.latent_size, c.dim)
        if self.c.split:
            self.input2 = nn.Linear(c.latent_size, c.dim)

        self.shape_rotator = ShapeRotator(c.stack_config.dim//c.stack_config.n_head, c.stack_config.seq_len, theta=c.stack_config.theta)

        self.layers = nn.ModuleList([
            PerfBlock(
                dim=c.stack_config.dim,
                layer_id=l,
                n_head=c.stack_config.n_head,
                kv_heads=c.stack_config.kv_heads,
                ff_dim=c.stack_config.ff_dim,
                eps=c.stack_config.eps,
                shape_rotator=self.shape_rotator,
            ) for l in range(c.stack_config.layers)
        ])

        self.output = GPTOutput(c.dim, c.vocab_size)
        if self.c.split:
            self.output2 = GPTOutput(c.dim, c.vocab_size)

        self.cache = [None] * c.stack_config.layers
        self.kv_heads = c.stack_config.kv_heads or c.stack_config.n_head
        self.head_dim = c.stack_config.dim // c.stack_config.n_head

        if exists(c.from_pretrained):
            result = self.load_state_dict(checkpoint, strict=False)
            print0_colored(result, 'yellow')

        self.resynthesizer = c.resynthesizer_config().eval()
        self.resynthesizer.requires_grad = False

        self.audio_tokenizer = make_tokenizer(device='cpu')
        self.audio_cache = None
        self.audio_latent_cache = None
        self.use_audio_cache = False

    @T.no_grad()
    def tokenize(self, audio_data):
        orig_audio_shape = audio_data.shape
        if exists(self.audio_cache):
            audio_data = T.cat([self.audio_cache, audio_data], dim=-1)
            self.audio_cache = audio_data[..., -(6*16_000):]
        elif self.use_audio_cache:
            self.audio_cache = audio_data[..., -(6*16_000):]

        if audio_data.shape[1] == 2:
            enc_ch1 = self.audio_tokenizer.latent_from_data(audio_data[:, 0:1])
            enc_ch2 = self.audio_tokenizer.latent_from_data(audio_data[:, 1:2])
            return T.cat([enc_ch1, enc_ch2], dim=-1)[:, -(orig_audio_shape[-1]//2000):]
        else:
            return self.audio_tokenizer.latent_from_data(audio_data)[:, -(orig_audio_shape[-1]//2000):]

    @T.no_grad()
    def untokenize(self, token_data):
        if exists(self.audio_latent_cache):
            token_data = T.cat([self.audio_latent_cache, token_data], dim=1)
            self.audio_latent_cache = token_data[:, -(6*8):]
        elif self.use_audio_cache:
            self.audio_latent_cache = token_data[:, -(6*8):]

        if token_data.shape[-1] == 2*self.c.latent_size:
            dec_ch1 = self.audio_tokenizer.data_from_latent(token_data[:, :self.c.latent_size])
            dec_ch2 = self.audio_tokenizer.data_from_latent(token_data[:, self.c.latent_size:])
            return T.cat([dec_ch1, dec_ch2], dim=1)[..., -(token_data.shape[1]*2000):]
        else:
            return self.audio_tokenizer.data_from_latent(token_data)[..., -(token_data.shape[1]*2000):]

    def init_cache(self, bsize, device, dtype, length: int = None):
        cache_shape = [self.c.stack_config.layers, length or self.c.stack_config.seq_len, 2, self.kv_heads, self.head_dim]
        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
        self.resynthesizer.init_cache(bsize, device, dtype, length)
        self.use_audio_cache = True

    def deinit_cache(self):
        self.cache = [None] * len(self.layers)
        self.resynthesizer.deinit_cache()
        self.audio_cache = None
        self.audio_latent_cache = None
        self.use_audio_cache = False

    @T.no_grad()
    def forward(self, data):
        if self.c.split:
            x1, x2 = data.chunk(2, dim=-1)
            x = self.input(x1) + self.input2(x2)
        else:
            x = self.input(data)

        for l, layer in enumerate(self.layers):
            x = layer(x, kv=self.cache[l])

        if self.c.split:
            return self.output(x), self.output2(x)
        else:
            return self.output(x)

    @T.no_grad()
    def next_audio_from_audio(self, audio_data: T.Tensor, temps=(0.8, (0.5, 0.1))):
        latents_in = self.tokenize(audio_data)
        next_latents = self.next_latent(latents_in, temps)
        next_model_latent = next_latents[..., self.c.latent_size:]
        audio_decoded = self.untokenize(next_model_latent)[..., -2000:]
        return audio_decoded

    @T.no_grad()
    def next_latent(self, model_input: T.Tensor, temps=(0.8, (0.5, 0.1))):
        if self.c.split:
            logits1, logits2 = self.forward(model_input)
            next_logits1 = logits1[:, -1]
            next_logits2 = logits2[:, -1]
            next_token1 = F.softmax(next_logits1 / temps[0], dim=-1).multinomial(1)
            next_token2 = F.softmax(next_logits2 / temps[0], dim=-1).multinomial(1)

            next_input = self.resynthesizer(model_input, next_tokens=(next_token1, next_token2), temps=temps[1])
        else:
            logits = self.forward(model_input)
            next_logits = logits[:, -1]
            next_token = F.softmax(next_logits / temps[0], dim=-1).multinomial(1)

            next_input = self.resynthesizer(model_input, next_tokens=next_token, temps=temps[1])

        return next_input

    @T.no_grad()
    def completion(self, data: T.Tensor, temps=(0.8, (0.5, 0.1)), gen_len=None, use_cache=True) -> T.Tensor:
        """
        only accepts latent-space data.
        """
        if use_cache:
            self.init_cache(data.shape[0], data.device, T.bfloat16)

        next_input = generated = data

        target_len = min(data.shape[1] + default(gen_len, data.shape[1]), self.c.stack_config.seq_len)

        for _ in tqdm0(range(data.shape[1], target_len)):
            model_input = next_input if use_cache else generated

            next_input = self.next_latent(model_input, temps)

            generated = T.cat([generated, next_input], dim=1)

        if use_cache:
            self.deinit_cache()
        return generated


def get_hertz_dev_config(is_split=True):
    if is_split:
        checkpoints = [('inference_care_50000', 'e4ff4fe5c7e9f066410d2a5673b7a935'), ('inference_scion_54000', 'cb8bc484423922747b277ebc2933af5d')]
    else:
        checkpoints = [('inference_whip_72000', '5e7cee7316900737d55fc5d44cc7a8f7'), ('inference_caraway_112000', 'fcb8368ef8ebf7712f3e31e6856da580')]

    quantizer_config = LatentQuantizer.Config(
        from_pretrained=('inference_volcano_3', 'd42bf674022c5f84b051d5d7794f6169'),
        compressor_config=FSQ.Config(
            levels=[8, 8, 8, 8, 8],
            dim=2048,
            num_codebooks=1,
            keep_num_codebooks_dim=None,
            scale=None,
            allowed_dtypes=['float32', 'float64', 'bfloat16'],
            channel_first=False,
            projection_has_bias=True,
            return_indices=True,
            force_quantization_f32=True,
            use_rms=False
        ),
        dim=2048,
        ff_dim=8192,
        input_dim=32
    )

    resynthesizer_config = TransformerVAE.Config(
        io_config=GaussianMixtureIOLayer.Config(
            latent_dim=32,
            dim=4096,
            num_components=8,
        ),
        stack_config=Stack.Config(
            layers=8,
            dim=4096,
            seq_len=8192,
            n_head=16,
            ff_dim=11008,
            kv_heads=16,
            eps=1e-5,
            theta=10_000
        ),
        quantizer_config=quantizer_config,
        plex_layer=None,
        plex_roll=1,
        split=is_split,
        from_pretrained=checkpoints[0],
    )

    return HertzDevModel.Config(
        dim=4096,
        vocab_size=32_768,
        stack_config=Stack.Config(
            layers=32,
            dim=4096,
            seq_len=2048,
            n_head=32,
            ff_dim=None,
            kv_heads=None,
            eps=1e-5,
            theta=10_000,
        ),
        quantizer_config=quantizer_config,
        resynthesizer_config=resynthesizer_config,
        split=is_split,
        from_pretrained=checkpoints[1],
    )
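
A rough sketch of how the deleted model.py entry points fit together (editorial illustration, untested, not taken from inference_client.py or the notebook). It assumes a CUDA device and that utils.load_ckpt can fetch the checkpoints referenced above:

# Hypothetical sketch: build the full duplex model and predict the next audio hop.
import torch
from model import get_hertz_dev_config

config = get_hertz_dev_config(is_split=True)
model = config()                                   # builds HertzDevModel and loads the listed checkpoints
model = model.to('cuda', torch.bfloat16).eval()
model.init_cache(bsize=1, device='cuda', dtype=torch.bfloat16)

# feed one 125 ms stereo hop (2000 samples at 16 kHz) and get the predicted next hop back
chunk = torch.zeros(1, 2, 2000, device='cuda', dtype=torch.bfloat16)
next_chunk = model.next_audio_from_audio(chunk, temps=(0.8, (0.5, 0.1)))
model.deinit_cache()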
prompts/bob_duo.wav
DELETED
Binary file (554 kB)

prompts/bob_mono.wav
DELETED
Binary file (560 kB)

prompts/countdown_mono.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5f22399fd8039c043758ef527a588f608394f761395a437ffc02077fa0dca517
size 1664746

prompts/toaskanymore.wav
DELETED
@@ -1,3 +0,0 @@
version https://git-lfs.github.com/spec/v1
oid sha256:5fa50bbee5e0f5a952f1b3d8afa2c051e00530b1c9040fcf45e7803a89ce6881
size 1088044
requirements.txt
DELETED
@@ -1,14 +0,0 @@
torch
torchaudio
einops
tqdm
IPython
numpy
soundfile
websockets
requests
sounddevice
matplotlib
fastapi
uvicorn
argparse
tokenizer.py
DELETED
@@ -1,581 +0,0 @@
import math
from dataclasses import dataclass
from typing import Union, Tuple, Literal

import torch as T
import torch.nn as nn
from torch.nn.utils.parametrizations import weight_norm

from utils import load_ckpt
from utils.interp import print_colored
from utils import si_module, get_activation


# Adapted from https://github.com/facebookresearch/AudioDec

def Conv1d1x1(in_channels, out_channels, bias=True):
    return nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=bias)


class NonCausalConv1d(nn.Module):
    """1D noncausal convolution w/ 2-sides padding."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        padding=-1,
        dilation=1,
        groups=1,
        bias=True):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        if padding < 0:
            padding = (kernel_size - 1) // 2 * dilation
        self.dilation = dilation
        self.conv = nn.Conv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C, T).
        """
        x = self.conv(x)
        return x


class NonCausalConvTranspose1d(nn.Module):
    """1D noncausal transpose convolution."""

    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        padding=-1,
        output_padding=-1,
        groups=1,
        bias=True,
    ):
        super().__init__()
        if padding < 0:
            padding = (stride+1) // 2
        if output_padding < 0:
            output_padding = 1 if stride % 2 else 0
        self.deconv = nn.ConvTranspose1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
            output_padding=output_padding,
            groups=groups,
            bias=bias,
        )

    def forward(self, x):
        """
        Args:
            x (Tensor): Float tensor variable with the shape (B, C, T).
        Returns:
            Tensor: Float tensor variable with the shape (B, C', T').
        """
        x = self.deconv(x)
        return x


class CausalConv1d(NonCausalConv1d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride=1,
        dilation=1,
        groups=1,
        bias=True
    ):
        super(CausalConv1d, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            dilation=dilation,
            groups=groups,
            bias=bias,
        )
        self.stride = stride
        self.pad_length = (kernel_size - 1) * dilation

    def forward(self, x):
        pad = nn.ConstantPad1d((self.pad_length, 0), 0.0)
        x = pad(x)
        return self.conv(x)


class CausalConvTranspose1d(NonCausalConvTranspose1d):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size,
        stride,
        bias=True,
        pad_buffer=None,
    ):
        super(CausalConvTranspose1d, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            padding=0,
            output_padding=0,
            bias=bias,
        )
        self.stride = stride
        self.pad_length = (math.ceil(kernel_size/stride) - 1)
        if pad_buffer is None:
            pad_buffer = T.zeros(1, in_channels, self.pad_length)
        self.register_buffer("pad_buffer", pad_buffer)

    def forward(self, x):
        pad = nn.ReplicationPad1d((self.pad_length, 0))
        x = pad(x)
        return self.deconv(x)[:, :, self.stride : -self.stride]

    def inference(self, x):
        x = T.cat((self.pad_buffer, x), -1)
        self.pad_buffer = x[:, :, -self.pad_length:]
        return self.deconv(x)[:, :, self.stride : -self.stride]

    def reset_buffer(self):
        self.pad_buffer.zero_()


class NonCausalResUnit(nn.Module):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=7,
        dilation=1,
        bias=False,
    ):
        super().__init__()
        self.activation = nn.ELU()
        self.conv1 = NonCausalConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )
        self.conv2 = Conv1d1x1(out_channels, out_channels, bias)

    def forward(self, x):
        y = self.conv1(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y


class CausalResUnit(NonCausalResUnit):
    def __init__(
        self,
        in_channels,
        out_channels,
        kernel_size=7,
        dilation=1,
        bias=False,
    ):
        super(CausalResUnit, self).__init__(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            dilation=dilation,
            bias=bias,
        )
        self.conv1 = CausalConv1d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=1,
            dilation=dilation,
            bias=bias,
        )

    def inference(self, x):
        y = self.conv1.inference(self.activation(x))
        y = self.conv2(self.activation(y))
        return x + y


class ResNetBlock(nn.Module):
    def __init__(self,
        in_channels,
        out_channels,
        stride,
        kernel_size=7,
        dilations=(1, 3, 9),
        bias=True,
        mode='encoder',
    ):
        super().__init__()
        assert mode in ('encoder', 'decoder'), f"Mode ({mode}) is not supported!"

        self.mode = mode
        self.stride = stride

        ConvUnit = CausalConv1d if mode == 'encoder' else CausalConvTranspose1d

        res_channels = in_channels if mode == 'encoder' else out_channels

        res_units = [CausalResUnit(
            res_channels,
            res_channels,
            kernel_size=kernel_size,
            dilation=dilation,
        ) for dilation in dilations]

        if in_channels == out_channels:
            if mode == 'encoder':
                self.pool = nn.AvgPool1d(kernel_size=stride, stride=stride)
            if mode == 'decoder':
                self.upsample = nn.Upsample(scale_factor=stride, mode='nearest')
            conv_unit = nn.Conv1d(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=1,
                bias=bias,
            ) if in_channels != out_channels else nn.Identity()
        else:
            conv_unit = ConvUnit(
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=(2 * stride),
                stride=stride,
                bias=bias,
            )

        if mode == 'encoder':
            if in_channels == out_channels:
                self.res_block = nn.Sequential(*res_units, self.pool, conv_unit)
            else:
                self.res_block = nn.Sequential(*res_units, conv_unit)
        elif mode == 'decoder':
            if in_channels == out_channels:
                self.res_block = nn.Sequential(self.upsample, conv_unit, *res_units)
            else:
                self.res_block = nn.Sequential(conv_unit, *res_units)

    def forward(self, x):
        out = x
        for unit in self.res_block:
            out = unit(out)
        return out

    def inference(self, x):
        for unit in self.res_block:
            x = unit.inference(x)
        return x


@si_module
class ResNetStack(nn.Module):
    """
    ResNet encoder or decoder stack. Channel ratios
    and strides take the default order of from
    data/io-layer, to the middle of the model.
    """
    class Config:
        input_channels: int = 1
        output_channels: int = 1
        encode_channels: int = 32
        decode_channel_multiplier: int = 1
        latent_dim: int = None
        kernel_size: int = 7
        bias: bool = True
        channel_ratios: Tuple[int, ...] = (2, 4, 8, 16)
        strides: Tuple[int, ...] = (3, 4, 5, 5)
        mode: Literal['encoder', 'decoder'] = 'encoder'

    def __init__(self, c: Config):
        super().__init__()
        assert c.mode in ('encoder', 'decoder'), f"Mode ({c.mode}) is not supported!"

        self.mode = c.mode

        assert len(c.channel_ratios) == len(c.strides)
        channel_ratios = (1,) + c.channel_ratios
        strides = c.strides
        self.middle_channels = c.encode_channels * channel_ratios[-1]
        if c.mode == 'decoder':
            channel_ratios = tuple(reversed(channel_ratios))
            strides = tuple(reversed(strides))

        self.multiplier = c.decode_channel_multiplier if c.mode == 'decoder' else 1
        res_blocks = [ResNetBlock(
            c.encode_channels * channel_ratios[s_idx] * self.multiplier,
            c.encode_channels * channel_ratios[s_idx+1] * self.multiplier,
            stride,
            kernel_size=c.kernel_size,
            bias=c.bias,
            mode=c.mode,
        ) for s_idx, stride in enumerate(strides)]

        data_conv = CausalConv1d(
            in_channels=c.input_channels if c.mode == 'encoder' else c.encode_channels * self.multiplier,
            out_channels=c.encode_channels if c.mode == 'encoder' else c.output_channels,
            kernel_size=c.kernel_size,
            stride=1,
            bias=False,
        )

        if c.mode == 'encoder':
            self.res_stack = nn.Sequential(data_conv, *res_blocks)
        elif c.mode == 'decoder':
            self.res_stack = nn.Sequential(*res_blocks, data_conv)

        if c.latent_dim is not None:
            self.latent_proj = Conv1d1x1(self.middle_channels, c.latent_dim, bias=c.bias) if c.mode == 'encoder' else Conv1d1x1(c.latent_dim, self.middle_channels, bias=c.bias)
        if self.multiplier != 1:
            self.multiplier_proj = Conv1d1x1(self.middle_channels, self.middle_channels * self.multiplier, bias=c.bias)

    def forward(self, x, return_feats=False):
        if self.c.latent_dim is not None and self.mode == 'decoder':
            x = self.latent_proj(x)
        if self.multiplier != 1:
            x = self.multiplier_proj(x)

        feats = []
        for block in self.res_stack:
            x = block(x)
            if return_feats:
                feats.append(x)
        if self.c.latent_dim is not None and self.mode == 'encoder':
            x = self.latent_proj(x)
            if return_feats:
                feats.append(x)
        if return_feats:
            return feats
        return x

    def inference(self, x):
        for block in self.res_stack:
            x = block.inference(x)
        return x

    def reset_buffer(self):
        def _reset_buffer(m):
            if isinstance(m, CausalConv1d) or isinstance(m, CausalConvTranspose1d):
                m.reset_buffer()
        self.apply(_reset_buffer)

    def reset_parameters(self):
        def _reset_parameters(m):
            if isinstance(m, (nn.Conv1d, nn.ConvTranspose1d)):
                m.weight.data.normal_(0.0, 0.01)

        self.apply(_reset_parameters)

    def apply_weight_norm(self):
        def _apply_weight_norm(m):
            if isinstance(m, nn.Conv1d) or isinstance(
                m, nn.ConvTranspose1d
            ):
                nn.utils.parametrizations.weight_norm(m)

        self.apply(_apply_weight_norm)

    def remove_weight_norm(self):
        def _remove_weight_norm(m):
            try:
                print(m)
                nn.utils.remove_weight_norm(m)
            except ValueError:  # this module didn't have weight norm
                return

        self.apply(_remove_weight_norm)


@si_module
class GaussianZ(nn.Module):
    class Config:
        dim: int
        latent_dim: int
        bias: bool = False
        use_weight_norm: bool = False

    def __init__(self, c: Config):
        super().__init__()

        self.proj_in = nn.Linear(c.dim, c.latent_dim * 2, bias=c.bias)
        self.proj_out = nn.Linear(c.latent_dim, c.dim, bias=c.bias)

        if c.use_weight_norm:
            self.proj_in = weight_norm(self.proj_in)
            self.proj_out = weight_norm(self.proj_out)

    def reparam(self, mu, logvar):
        std = T.exp(logvar / 2)
        eps = T.randn_like(std)
        return mu + eps * std

    def kl_divergence(self, mu, logvar):
        return T.mean(-0.5 * T.sum(
            1 + logvar - mu.pow(2) - logvar.exp(),
            dim=(1, 2))
        )

    def repr_from_latent(self, latent: Union[dict, T.Tensor]):
        if isinstance(latent, T.Tensor):
            z = latent
        else:
            z = self.reparam(latent['mu'], latent['logvar'])
        l = self.proj_out(z)
        return l

    def forward(self, x: T.Tensor) -> Tuple[T.Tensor, dict]:
        mu, logvar = self.proj_in(x).chunk(2, dim=-1)
        kl_div = self.kl_divergence(mu, logvar)
        z = self.reparam(mu, logvar)
        xhat = self.proj_out(z)
        latent = {'mu': mu, 'logvar': logvar, 'z': z, 'kl_divergence': kl_div}
        return xhat, latent


@si_module
class WaveCodec(nn.Module):
    class Config:
        resnet_config: ResNetStack.Config = None
        sample_rate: int = 16_000
        use_weight_norm: bool = False

        compressor_config: dataclass = None

        norm_stddev: float = 1.0

    def __init__(self, c: Config):
        super().__init__()
        self.norm_stddev = c.norm_stddev
        self.encoder = c.resnet_config(mode='encoder')
        self.sample_rate = c.sample_rate

        self.total_stride = 1
        for stride in c.resnet_config.strides:
            self.total_stride *= stride
        self.tokens_per_second = self.sample_rate / self.total_stride

        self.compressor = c.compressor_config(dim=self.encoder.middle_channels)

        self.decoder = c.resnet_config(mode='decoder')

        if c.use_weight_norm:
            self.encoder.apply_weight_norm()
            self.decoder.apply_weight_norm()
            self.encoder.reset_parameters()
            self.decoder.reset_parameters()

    def encode(self, data):
        return self.encoder(data/self.norm_stddev)

    def decode(self, latent):
        return self.decoder(latent.transpose(1, 2))*self.norm_stddev

    @T.no_grad()
    def latent_from_data(self, data, get_parameters=False):
        x = self.encode(data)
        l_in = x.transpose(1, 2)
        l, latent = self.compressor(l_in)
        return latent['z'] if not get_parameters else {
            'mu': latent['mu'],
            'logvar': latent['logvar'],
            'z': latent['z'],
        }

    @T.no_grad()
    def data_from_latent(self, latent):
        l = self.compressor.repr_from_latent(latent)
        x = self.decode(l)
        return x

    def process(self, x):
        return self.latent_from_data(x)

    def unprocess(self, latent):
        return self.data_from_latent(latent)

    def forward(self, audio_input):
        x = self.encode(audio_input)

        l_in = x.transpose(1, 2)
        l, latent = self.compressor(l_in)

        xhat = self.decode(l)
        return xhat, latent


def make_tokenizer(device='cuda'):
    generator_config = WaveCodec.Config(
        resnet_config=ResNetStack.Config(
            input_channels=1,
            output_channels=1,
            encode_channels=16,
            decode_channel_multiplier=4,
            kernel_size=7,
            bias=True,
            channel_ratios=(4, 8, 16, 16, 16, 16),
            strides=(2, 2, 4, 5, 5, 5),
            mode=None,
        ),
        use_weight_norm=True,

        compressor_config=GaussianZ.Config(
            dim=None,
            latent_dim=32,

            bias=True,
            use_weight_norm=True
        ),

        norm_stddev=0.05,
    )
    checkpoint = load_ckpt("inference_apatosaurus_95000", expected_hash="ba876edb97b988e9196e449dd176ca97")

    tokenizer = generator_config()

    load_result = tokenizer.load_state_dict(checkpoint, strict=False)
    print_colored(f"Loaded tokenizer state dict: {load_result}", "grey")

    tokenizer = tokenizer.eval()
    # Only convert to bfloat16 if using CUDA
    if device == 'cuda':
        tokenizer = tokenizer.bfloat16()
    tokenizer = tokenizer.to(device)
    tokenizer.requires_grad_ = False
    return tokenizer
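
A small sketch of round-tripping audio through the deleted WaveCodec tokenizer (editorial illustration, not part of the commit). It assumes the inference_apatosaurus_95000 checkpoint is reachable through utils.load_ckpt:

# Hypothetical sketch: encode 1 s of mono 16 kHz audio to 8 latent frames and back.
import torch
from tokenizer import make_tokenizer

tokenizer = make_tokenizer(device='cpu')
audio = torch.zeros(1, 1, 16_000)              # (B, C, T)
latent = tokenizer.latent_from_data(audio)     # roughly (1, 8, 32): 8 frames/s, 32-dim latents
recon = tokenizer.data_from_latent(latent)     # back to a waveform
print(latent.shape, recon.shape)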
transformer.py
DELETED
@@ -1,382 +0,0 @@
|
|
1 |
-
from typing import Optional, Tuple, MutableMapping
|
2 |
-
from typing import Union
|
3 |
-
import math
|
4 |
-
from contextlib import nullcontext
|
5 |
-
|
6 |
-
import torch
|
7 |
-
import torch as T
|
8 |
-
import torch.nn as nn
|
9 |
-
import torch.nn.functional as F
|
10 |
-
from torch import Tensor
|
11 |
-
from torch.nn.attention import SDPBackend
|
12 |
-
|
13 |
-
from einops import rearrange
|
14 |
-
|
15 |
-
from utils import si_module, default, exists, load_ckpt
|
16 |
-
|
17 |
-
CACHE_FILL_VALUE = -1
|
18 |
-
|
19 |
-
def get_cache_len(cache: Optional[Tensor]) -> int:
|
20 |
-
"""
|
21 |
-
cache: (batch, seq_len, 2, kv_heads, head_dim)
|
22 |
-
"""
|
23 |
-
if cache is None:
|
24 |
-
return 0
|
25 |
-
nonzeros = T.any(cache.flatten(2) != CACHE_FILL_VALUE, dim=-1)
|
26 |
-
length = nonzeros.sum(dim=-1).int()
|
27 |
-
assert T.all(length == length[0])
|
28 |
-
return length[0]
|
29 |
-
|
30 |
-
|
31 |
-
def rotate_half(x):
|
32 |
-
x1, x2 = x.chunk(2, dim=-1)
|
33 |
-
return torch.cat((-x2, x1), dim=-1)
|
34 |
-
|
35 |
-
|
36 |
-
def apply_rotary_pos_emb(x, cos, sin, offset: int = 0):
|
37 |
-
assert (
|
38 |
-
cos.shape[1] >= offset + x.shape[1]
|
39 |
-
), f"Offset and/or input sequence is too large,\
|
40 |
-
\n offset: {offset}, seq_len: {x.shape[1]}, max: {cos.shape[1]}"
|
41 |
-
|
42 |
-
cos_out = cos[:, offset : offset + x.shape[1], :, :]
|
43 |
-
sin_out = sin[:, offset : offset + x.shape[1], :, :]
|
44 |
-
|
45 |
-
return (x * cos_out) + (rotate_half(x) * sin_out)
|
46 |
-
|
47 |
-
|
48 |
-
# Adapted from https://github.com/foundation-model-stack/foundation-model-stack
|
49 |
-
class ShapeRotator:
|
50 |
-
def __init__(
|
51 |
-
self,
|
52 |
-
dim: int,
|
53 |
-
end: int,
|
54 |
-
theta: float = 10_000,
|
55 |
-
):
|
56 |
-
super().__init__()
|
57 |
-
self.dim = dim
|
58 |
-
self.ratio = theta
|
59 |
-
self.cached_freqs: MutableMapping[int, MutableMapping[int, torch.Tensor]] = {}
|
60 |
-
self.max_seq_len_cached: MutableMapping[int, int] = {}
|
61 |
-
self.ntk_scaling = False
|
62 |
-
self.max_seq_len = end
|
63 |
-
|
64 |
-
def compute_freqs_cis(self, device, max_seq_len=None):
|
65 |
-
alpha = 1
|
66 |
-
dev_idx = device.index
|
67 |
-
max_seq_len = default(max_seq_len, self.max_seq_len)
|
68 |
-
|
69 |
-
if dev_idx not in self.cached_freqs:
|
70 |
-
self.cached_freqs[dev_idx] = {}
|
71 |
-
if dev_idx not in self.max_seq_len_cached:
|
72 |
-
self.max_seq_len_cached[dev_idx] = 0
|
73 |
-
|
74 |
-
|
75 |
-
if self.max_seq_len_cached[dev_idx] > 0:
|
76 |
-
return 1
|
77 |
-
max_seq_len = max(max_seq_len, self.max_seq_len)
|
78 |
-
|
79 |
-
if (
|
80 |
-
1 in self.cached_freqs[dev_idx]
|
81 |
-
and max_seq_len <= self.max_seq_len_cached[dev_idx]
|
82 |
-
):
|
83 |
-
return 1
|
84 |
-
|
85 |
-
ratio = self.ratio
|
86 |
-
dim = self.dim
|
87 |
-
|
88 |
-
freqs = 1.0 / (ratio ** (torch.arange(0, dim, 2, device=device).float() / dim))
|
89 |
-
|
90 |
-
t = torch.arange(max_seq_len, device=device, dtype=freqs.dtype)
|
91 |
-
freqs = torch.einsum("i,j->ij", t, freqs)
|
92 |
-
emb = torch.cat((freqs, freqs), dim=-1).to(device)
|
93 |
-
|
94 |
-
cos_to_cache = emb.cos()[None, :, None, :]
|
95 |
-
sin_to_cache = emb.sin()[None, :, None, :]
|
96 |
-
|
97 |
-
self.max_seq_len_cached[dev_idx] = max_seq_len
|
98 |
-
|
99 |
-
self.cached_freqs[dev_idx][alpha] = torch.stack(
|
100 |
-
[
|
101 |
-
cos_to_cache,
|
102 |
-
sin_to_cache,
|
103 |
-
],
|
104 |
-
dim=-1,
|
105 |
-
)
|
106 |
-
|
107 |
-
return alpha
|
108 |
-
|
109 |
-
def rotate(
|
110 |
-
self,
|
111 |
-
q: Tensor,
|
112 |
-
k: Tensor,
|
113 |
-
offset: int = 0,
|
114 |
-
) -> Tuple[Tensor, Tensor]:
|
115 |
-
"""
|
116 |
-
Args
|
117 |
-
----
|
118 |
-
q : torch.Tensor
|
119 |
-
Embedded query tensor, expected size is B x S x H x Eh
|
120 |
-
k : torch.Tensor
|
121 |
-
Embedded query tensor, expected size is B x S x H x Eh
|
122 |
-
"""
|
123 |
-
assert len(q.size()) == 4
|
124 |
-
assert len(k.size()) == 4
|
125 |
-
|
126 |
-
seq_len = self.max_seq_len
|
127 |
-
alpha = self.compute_freqs_cis(q.device, seq_len)
|
128 |
-
freqs = self.cached_freqs[q.device.index][alpha]
|
129 |
-
|
130 |
-
freqs = freqs.float() # 1 L D/2 2 2
|
131 |
-
q_out = apply_rotary_pos_emb(q, freqs[..., 0], freqs[..., 1], offset=offset).type_as(q)
|
132 |
-
k_out = apply_rotary_pos_emb(k, freqs[..., 0], freqs[..., 1], offset=offset).type_as(k)
|
133 |
-
|
134 |
-
return q_out.view_as(q), k_out.view_as(k)
|
135 |
-
|
136 |
-
class Linear(nn.Linear):
|
137 |
-
def __init__(self, *args, **kwargs):
|
138 |
-
super().__init__(*args, **kwargs, bias=False)
|
139 |
-
|
140 |
-
class Norm(nn.Module):
|
141 |
-
def __init__(self,
|
142 |
-
dim: int,
|
143 |
-
eps: float = 1e-5,) -> None:
|
144 |
-
super().__init__()
|
145 |
-
self.eps = eps
|
146 |
-
self.weight = nn.Parameter(T.ones((dim,)))
|
147 |
-
|
148 |
-
def forward(self, input: Tensor) -> Tensor:
|
149 |
-
return F.layer_norm(input, (self.weight.shape[0],), weight=self.weight, bias=None, eps=self.eps)
|
150 |
-
|
151 |
-
|
152 |
-
class FFNN(nn.Module):
|
153 |
-
def __init__(self,
|
154 |
-
dim: int,
|
155 |
-
expand_dim: int = None,):
|
156 |
-
super().__init__()
|
157 |
-
expand_dim = default(expand_dim, 256 * ((int(2 * 4 * dim / 3) + 256 - 1) // 256))
|
158 |
-
self.dim = dim
|
159 |
-
self.expand_dim = expand_dim
|
160 |
-
|
161 |
-
self.gateup_proj = Linear(dim, 2*expand_dim)
|
162 |
-
self.down_proj = Linear(expand_dim, dim)
|
163 |
-
|
164 |
-
def forward(self, x):
|
165 |
-
gate, up = self.gateup_proj(x).chunk(2, dim=-1)
|
166 |
-
return self.down_proj(up * F.silu(gate))
|
167 |
-
|
168 |
-
class GQA(nn.Module):
|
169 |
-
def __init__(self,
|
170 |
-
dim: int,
|
171 |
-
n_head: int,
|
172 |
-
shape_rotator: ShapeRotator,
|
173 |
-
kv_heads: Optional[int] = None,
|
174 |
-
eps: float = 1e-5,
|
175 |
-
causal: bool = True,):
|
176 |
-
super().__init__()
|
177 |
-
self.n_heads = n_head
|
178 |
-
self.kv_heads = default(kv_heads, n_head)
|
179 |
-
self.head_dim = dim // n_head
|
180 |
-
self.causal = causal
|
181 |
-
|
182 |
-
self.proj_qkv = Linear(dim, self.head_dim*(n_head+2*self.kv_heads))
|
183 |
-
|
184 |
-
self.norm_q = Norm(self.head_dim*n_head, eps=eps)
|
185 |
-
self.norm_k = Norm(self.head_dim*self.kv_heads, eps=eps)
|
186 |
-
|
187 |
-
self.attn_out = Linear(dim, dim)
|
188 |
-
|
189 |
-
self.shape_rotator = shape_rotator
|
190 |
-
|
191 |
-
def _sdpa(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
192 |
-
k = k.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
|
193 |
-
v = v.repeat_interleave(self.n_heads // self.kv_heads, dim=2)
|
194 |
-
with nn.attention.sdpa_kernel(SDPBackend.FLASH_ATTENTION) if k.device.type == 'cuda' else nullcontext():
|
195 |
-
x = F.scaled_dot_product_attention(
|
196 |
-
q.transpose(1, 2),
|
197 |
-
k.transpose(1, 2),
|
198 |
-
v.transpose(1, 2),
|
199 |
-
is_causal=False if (q.size(1) != k.size(1)) else self.causal,
|
200 |
-
)
|
201 |
-
x = x.transpose(1, 2).contiguous()
|
202 |
-
return x
|
203 |
-
|
204 |
-
def _attend(self, q: Tensor, k: Tensor, v: Tensor, kv_cache: Optional[Tensor] = None,):
|
205 |
-
cache_len = get_cache_len(kv_cache)
|
206 |
-
q, k = self.shape_rotator.rotate(q, k, offset=cache_len)
|
-        if exists(kv_cache):
-            k = T.cat([kv_cache[:, :cache_len, 0], k], dim=1)
-            v = T.cat([kv_cache[:, :cache_len, 1], v], dim=1)
-            kv_cache[:, :k.size(1), 0] = k
-            kv_cache[:, :v.size(1), 1] = v
-        x = self._sdpa(q, k, v)
-        return self.attn_out(rearrange(x, 'b s h d -> b s (h d)'))
-
-    def _project(self, x):
-        full_q, full_k, full_v = self.proj_qkv(x).chunk(3, dim=-1)
-        normed_full_q = self.norm_q(full_q).to(full_q.dtype)
-        normed_full_k = self.norm_k(full_k).to(full_k.dtype)
-
-        q = rearrange(normed_full_q, 'b s (h d) -> b s h d', h=self.n_heads)
-        k = rearrange(normed_full_k, 'b s (h d) -> b s h d', h=self.kv_heads)
-        v = rearrange(full_v, 'b s (h d) -> b s h d', h=self.kv_heads)
-        return q, k, v
-
-    def forward(self,
-                x: Tensor,
-                kv: Optional[Tensor] = None,):
-        """
-        x: (B, S, D)
-        kv: (B, S, H, D)
-        """
-        q, k, v = self._project(x)
-        return self._attend(q, k, v, kv_cache=kv)
-
-
-class PreNormAttn(nn.Module):
-    def __init__(self,
-                 dim: int,
-                 n_head: int,
-                 shape_rotator: ShapeRotator,
-                 kv_heads: Optional[int] = None,
-                 eps: float = 1e-5,
-                 causal: bool = True,):
-        super().__init__()
-        self.attn_norm = Norm(dim, eps=eps)
-        self.attn = GQA(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
-
-    def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
-        """
-        x: (B, S, D)
-        kv: (B, S, H, D)
-        """
-        return x + self.attn(self.attn_norm(x), kv)
-
-class PreNormFFNN(nn.Module):
-    def __init__(self,
-                 dim: int,
-                 ff_dim: int,
-                 eps: float = 1e-5,):
-        super().__init__()
-        self.ffnn_norm = Norm(dim, eps=eps)
-        self.ffnn = FFNN(dim, ff_dim)
-
-    def forward(self, x: Tensor) -> Tensor:
-        return x + self.ffnn(self.ffnn_norm(x))
-
-class Block(nn.Module):
-    def __init__(self,
-                 dim: int,
-                 layer_id: int = 0,
-                 n_head: int = 16,
-                 kv_heads: Optional[int] = None,
-                 ff_dim: Optional[int] = None,
-                 eps: float = 1e-5,
-                 causal: bool = True,
-                 shape_rotator: ShapeRotator = None):
-        super().__init__()
-        self.attn = PreNormAttn(dim, n_head, shape_rotator, kv_heads, eps=eps, causal=causal)
-        self.ffnn = PreNormFFNN(dim, ff_dim, eps=eps)
-        self.dim = dim
-        self.layer_id = layer_id
-        self.head_dim = dim // n_head
-        self.expand_dim = self.ffnn.ffnn.expand_dim
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        std = 1.0 / math.sqrt(self.dim)
-        nn.init.trunc_normal_(self.ffnn.ffnn.gateup_proj.weight, std=std, a=-3 * std, b=3 * std)
-        nn.init.trunc_normal_(self.attn.attn.proj_qkv.weight, std=std, a=-3 * std, b=3 * std)
-        nn.init.trunc_normal_(self.attn.attn.attn_out.weight, std=std, a=-3 * std, b=3 * std)
-
-        xstd = 1.0 / math.sqrt(self.expand_dim)
-        nn.init.trunc_normal_(self.ffnn.ffnn.down_proj.weight, std=xstd, a=-3 * xstd, b=3 * xstd)
-
-    def forward(self, x: Tensor, kv: Optional[Tensor] = None) -> Tensor:
-        """
-        x: (B, S, D)
-        kv: (B, S, H, D)
-        """
-        h = self.attn(x, kv)
-        out = self.ffnn(h)
-        return out
-
-
-
-class GPTOutput(nn.Module):
-    def __init__(self, dim, vocab_size):
-        super().__init__()
-        self.dim = dim
-        self.norm = Norm(dim)
-        self.output = Linear(dim, vocab_size)
-
-        self.reset_parameters()
-
-    def reset_parameters(self):
-        std = 1.0 / math.sqrt(self.dim**2)
-        nn.init.trunc_normal_(self.output.weight, std=std, a=-3 * std, b=3 * std)
-
-    def forward(self, x):
-        return self.output(self.norm(x))
-
-@si_module
-class Stack(nn.Module):
-    class Config:
-        layers: int
-        dim: int
-        seq_len: int
-        n_head: int = 32
-        ff_dim: int = None
-        kv_heads: int = None
-        eps: float = 1e-5
-        theta: Union[int, float] = 10_000
-        causal: bool = True
-
-        from_pretrained: Optional[Tuple[str, int]] = None
-
-    def __init__(self, c: Config):
-        super().__init__()
-
-        from_pretrained = c.from_pretrained
-        if exists(from_pretrained):
-            checkpoint = load_ckpt(c.from_pretrained)
-
-        self.shape_rotator = ShapeRotator(c.dim//c.n_head, c.seq_len, theta=c.theta)
-
-        self.layers = nn.ModuleList([
-            Block(
-                dim=c.dim,
-                layer_id=l,
-                n_head=c.n_head,
-                kv_heads=c.kv_heads,
-                ff_dim=c.ff_dim,
-                eps=c.eps,
-                causal=c.causal,
-                shape_rotator=self.shape_rotator,
-            ) for l in range(c.layers)
-        ])
-
-        kv_heads = c.kv_heads or c.n_head
-        head_dim = c.dim // c.n_head
-        cache_shape = [c.layers, c.seq_len, 2, kv_heads, head_dim]
-        self.cache_shape = cache_shape
-        self.cache = [None] * c.layers
-
-        if exists(from_pretrained):
-            self.load_state_dict(checkpoint)
-
-    def init_cache(self, bsize, device, dtype, length:int=None):
-        if self.cache_shape is None:
-            return
-        cache_shape = self.cache_shape.copy()
-        cache_shape[1] = length or cache_shape[1]
-        self.cache = T.full((bsize, *cache_shape), CACHE_FILL_VALUE, device=device, dtype=dtype).transpose(0, 1)
-
-    def deinit_cache(self):
-        self.cache = [None] * len(self.cache)
-
-    def forward(self, x: Tensor) -> Tensor:
-        for l, layer in enumerate(self.layers):
-            x = layer(x, kv=self.cache[l])
-        return x
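Not shown in the diff itself: how the deleted Stack module was typically driven. A minimal, hedged usage sketch (the layer sizes below are illustrative, and it assumes the original transformer.py and its utils imports are still available):

import torch as T
from transformer import Stack

# Config is callable because of the si_module decorator, so it doubles as a factory.
stack = Stack.Config(layers=8, dim=1024, seq_len=2048, n_head=16, kv_heads=4)()

# Allocate the per-layer KV cache before incremental decoding, then feed one step at a time.
stack.init_cache(bsize=1, device=stack.device, dtype=stack.dtype, length=256)
x = T.randn(1, 1, 1024)
with T.no_grad():
    h = stack(x)      # (1, 1, 1024); this step's keys/values are written into stack.cache
stack.deinit_cache()  # free the cache once the stream is finished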
utils/__init__.py
DELETED
@@ -1,3 +0,0 @@
-from .blocks import *
-from .dist import *
-from .interp import *
utils/blocks.py
DELETED
@@ -1,92 +0,0 @@
-from dataclasses import dataclass
-from typing import TypeVar, Generic, Type, Optional
-from functools import wraps
-import time
-import random
-
-import torch as T
-import torch.nn as nn
-
-# @TODO: remove si_module from codebase
-# we use this in our research codebase to make modules from callable configs
-si_module_TpV = TypeVar('si_module_TpV')
-def si_module(cls: Type[si_module_TpV]) -> Type[si_module_TpV]:
-    if not hasattr(cls, 'Config') or not isinstance(cls.Config, type):
-        class Config:
-            pass
-        cls.Config = Config
-
-    cls.Config = dataclass(cls.Config)
-
-    class ConfigWrapper(cls.Config, Generic[si_module_TpV]):
-        def __call__(self, *args, **kwargs) -> si_module_TpV:
-            if len(kwargs) > 0:
-                config_dict = {field.name: getattr(self, field.name) for field in self.__dataclass_fields__.values()}
-                config_dict.update(kwargs)
-                new_config = type(self)(**config_dict)
-                return cls(new_config)
-            else:
-                return cls(self, *args)
-
-    ConfigWrapper.__module__ = cls.__module__
-    ConfigWrapper.__name__ = f"{cls.__name__}Config"
-    ConfigWrapper.__qualname__ = f"{cls.__qualname__}.Config"
-
-    cls.Config = ConfigWrapper
-
-    original_init = cls.__init__
-    def new_init(self, *args, **kwargs):
-        self.c = next((arg for arg in args if isinstance(arg, cls.Config)), None) or next((arg for arg in kwargs.values() if isinstance(arg, cls.Config)), None)
-        original_init(self, *args, **kwargs)
-        self.register_buffer('_device_tracker', T.Tensor(), persistent=False)
-
-    cls.__init__ = new_init
-
-    @property
-    def device(self):
-        return self._device_tracker.device
-
-    @property
-    def dtype(self):
-        return self._device_tracker.dtype
-
-    cls.device = device
-    cls.dtype = dtype
-
-    return cls
-
-
-def get_activation(nonlinear_activation, nonlinear_activation_params={}):
-    if hasattr(nn, nonlinear_activation):
-        return getattr(nn, nonlinear_activation)(**nonlinear_activation_params)
-    else:
-        raise NotImplementedError(f"Activation {nonlinear_activation} not found in torch.nn")
-
-
-def exists(v):
-    return v is not None
-
-def isnt(v):
-    return not exists(v)
-
-def truthyexists(v):
-    return exists(v) and v is not False
-
-def truthyattr(obj, attr):
-    return hasattr(obj, attr) and truthyexists(getattr(obj, attr))
-
-defaultT = TypeVar('defaultT')
-
-def default(*args: Optional[defaultT]) -> Optional[defaultT]:
-    for arg in args:
-        if exists(arg):
-            return arg
-    return None
-
-def maybe(fn):
-    @wraps(fn)
-    def inner(x, *args, **kwargs):
-        if not exists(x):
-            return x
-        return fn(x, *args, **kwargs)
-    return inner
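For context, si_module wraps a module's Config dataclass so that the config itself builds the module. A small hedged sketch of that pattern (the Toy module and its fields are made up for illustration):

import torch.nn as nn
from utils.blocks import si_module

@si_module
class Toy(nn.Module):
    class Config:
        dim: int
        scale: float = 1.0

    def __init__(self, c: 'Toy.Config'):
        super().__init__()
        self.proj = nn.Linear(c.dim, c.dim)

    def forward(self, x):
        return self.proj(x) * self.c.scale  # self.c is attached by the decorator

cfg = Toy.Config(dim=8)       # dataclass generated by the decorator
toy = cfg()                   # calling the config constructs the module
toy2 = cfg(scale=2.0)         # kwargs override fields and build another module directly
print(toy.device, toy.dtype)  # properties backed by the _device_tracker buffer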
utils/dist.py
DELETED
@@ -1,99 +0,0 @@
-import os
-import torch as T
-import re
-from tqdm import tqdm
-from datetime import timedelta
-
-import requests
-import hashlib
-
-from io import BytesIO
-
-def rank0():
-    rank = os.environ.get('RANK')
-    if rank is None or rank == '0':
-        return True
-    else:
-        return False
-
-def local0():
-    local_rank = os.environ.get('LOCAL_RANK')
-    if local_rank is None or local_rank == '0':
-        return True
-    else:
-        return False
-class tqdm0(tqdm):
-    def __init__(self, *args, **kwargs):
-        total = kwargs.get('total', None)
-        if total is None and len(args) > 0:
-            try:
-                total = len(args[0])
-            except TypeError:
-                pass
-        if total is not None:
-            kwargs['miniters'] = max(1, total // 20)
-        super().__init__(*args, **kwargs, disable=not rank0(), bar_format='{bar}| {n_fmt}/{total_fmt} [{rate_fmt}{postfix}]')
-
-def print0(*args, **kwargs):
-    if rank0():
-        print(*args, **kwargs)
-
-_PRINTED_IDS = set()
-
-def printonce(*args, id=None, **kwargs):
-    if id is None:
-        id = ' '.join(map(str, args))
-
-    if id not in _PRINTED_IDS:
-        print(*args, **kwargs)
-        _PRINTED_IDS.add(id)
-
-def print0once(*args, **kwargs):
-    if rank0():
-        printonce(*args, **kwargs)
-
-def init_dist():
-    if T.distributed.is_initialized():
-        print0('Distributed already initialized')
-        rank = T.distributed.get_rank()
-        local_rank = int(os.environ.get('LOCAL_RANK', 0))
-        world_size = T.distributed.get_world_size()
-    else:
-        try:
-            rank = int(os.environ['RANK'])
-            local_rank = int(os.environ['LOCAL_RANK'])
-            world_size = int(os.environ['WORLD_SIZE'])
-            device = f'cuda:{local_rank}'
-            T.cuda.set_device(device)
-            T.distributed.init_process_group(backend='nccl', timeout=timedelta(minutes=30), rank=rank, world_size=world_size, device_id=T.device(device))
-            print(f'Rank {rank} of {world_size}.')
-        except Exception as e:
-            print0once(f'Not initializing distributed env: {e}')
-            rank = 0
-            local_rank = 0
-            world_size = 1
-    return rank, local_rank, world_size
-
-def load_ckpt(load_from_location, expected_hash=None):
-    if local0():
-        os.makedirs('ckpt', exist_ok=True)
-        url = f"https://ckpt.si.inc/hertz-dev/{load_from_location}.pt"
-        save_path = f"ckpt/{load_from_location}.pt"
-        if not os.path.exists(save_path):
-            response = requests.get(url, stream=True)
-            total_size = int(response.headers.get('content-length', 0))
-            with open(save_path, 'wb') as f, tqdm(total=total_size, desc=f'Downloading {load_from_location}.pt', unit='GB', unit_scale=1/(1024*1024*1024)) as pbar:
-                for chunk in response.iter_content(chunk_size=8192):
-                    f.write(chunk)
-                    pbar.update(len(chunk))
-        if expected_hash is not None:
-            with open(save_path, 'rb') as f:
-                file_hash = hashlib.md5(f.read()).hexdigest()
-            if file_hash != expected_hash:
-                print(f'Hash mismatch for {save_path}. Expected {expected_hash} but got {file_hash}. Deleting checkpoint and trying again.')
-                os.remove(save_path)
-                return load_ckpt(load_from_location, expected_hash)
-    if T.distributed.is_initialized():
-        T.distributed.barrier() # so that ranks don't try to load checkpoint before it's finished downloading
-    loaded = T.load(f"ckpt/{load_from_location}.pt", weights_only=False, map_location='cpu')
-    return loaded
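A hedged sketch of how these distributed helpers fit together; the checkpoint name is purely illustrative (load_ckpt resolves names under https://ckpt.si.inc/hertz-dev/, and only the local rank-0 process downloads):

from utils.dist import init_dist, load_ckpt, print0

rank, local_rank, world_size = init_dist()  # falls back to a single process if RANK/WORLD_SIZE are unset
state = load_ckpt('some_checkpoint_name')   # hypothetical name; the optional expected_hash guards the download
print0(f'checkpoint loaded across {world_size} rank(s)')  # printed once, on rank 0 only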
utils/interp.py
DELETED
@@ -1,84 +0,0 @@
-import torch as T
-import os
-
-def rank0():
-    rank = os.environ.get('RANK')
-    if rank is None or rank == '0':
-        return True
-    else:
-        return False
-
-def print_colored(message, color='reset', bold=False, **kwargs):
-    color_dict = {
-        'bold': '\033[1m',
-        'green': '\033[92m',
-        'yellow': '\033[93m',
-        'red': '\033[91m',
-        'blue': '\033[94m',
-        'grey': '\033[90m',
-        'white': '\033[97m',
-        'reset': '\033[0m'
-    }
-
-    color_code = color_dict.get(color.lower(), color_dict['reset'])
-    prefix = color_dict['bold'] if bold else ''
-    print(f"{prefix}{color_code}{message}{color_dict['reset']}", **kwargs)
-
-def print0_colored(*args, **kwargs):
-    if rank0():
-        print_colored(*args, **kwargs)
-
-def param_count(module):
-    def count_parameters(model):
-        return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-    total_params = count_parameters(module)
-    output = [f'Total model parameters: {total_params:,}', '---------------------------']
-
-    for name, child in module.named_children():
-        params = count_parameters(child)
-        output.append(f'{name} parameters: {params:,}')
-
-    return '\n'.join(output)
-
-def model_size_estimation(module):
-    def estimate_size(model):
-        param_size = sum(p.nelement() * p.element_size() for p in model.parameters())
-        buffer_size = sum(b.nelement() * b.element_size() for b in model.buffers())
-        return param_size + buffer_size
-
-    total_size = estimate_size(module)
-    output = [f'Total model size: {total_size / 1024**2:.2f} MB', '---------------------------']
-
-    for name, child in module.named_children():
-        child_size = estimate_size(child)
-        output.append(f'{name} size: {child_size / 1024**2:.2f} MB')
-
-    return '\n'.join(output)
-
-def layer_param_distribution(module):
-    def count_parameters(model):
-        return sum(p.numel() for p in model.parameters() if p.requires_grad)
-
-    def get_layer_types(model):
-        layer_types = {}
-        for name, module in model.named_modules():
-            layer_type = module.__class__.__name__
-            params = sum(p.numel() for p in module.parameters(recurse=False) if p.requires_grad)
-            if params > 0:
-                if layer_type not in layer_types:
-                    layer_types[layer_type] = 0
-                layer_types[layer_type] += params
-        return layer_types
-
-    total_params = count_parameters(module)
-    layer_types = get_layer_types(module)
-
-    output = [f'Total trainable parameters: {total_params:,}', '---------------------------']
-
-    for layer_type, count in sorted(layer_types.items(), key=lambda x: x[1], reverse=True):
-        percentage = (count / total_params) * 100
-        output.append(f'{layer_type}: {count:,} ({percentage:.2f}%)')
-
-    return '\n'.join(output)
-
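These reporting helpers all return formatted strings rather than printing. A minimal hedged example pairing them with the rank-aware colored printer (the stand-in model is arbitrary):

import torch.nn as nn
from utils.interp import param_count, layer_param_distribution, print0_colored

model = nn.Sequential(nn.Linear(16, 64), nn.GELU(), nn.Linear(64, 16))  # stand-in model
print0_colored(param_count(model), color='blue')
print0_colored(layer_param_distribution(model), color='grey', bold=True)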