grupafox commited on
Commit
90ca2ad
·
1 Parent(s): 8d1ea49

Upload 16 files

Browse files
Files changed (16) hide show
  1. .gitignore +12 -0
  2. FUNDING.yml +1 -0
  3. LICENSE.md +33 -0
  4. README.md +190 -3
  5. clone_voice.ipynb +255 -0
  6. generate.ipynb +173 -0
  7. generate_chunked.ipynb +344 -0
  8. model-card.md +40 -0
  9. pyproject.toml +58 -0
  10. rvc_infer.py +169 -0
  11. rvc_test.ipynb +85 -0
  12. setup.py +3 -0
  13. test_models.ipynb +454 -0
  14. train_coarse.ipynb +936 -0
  15. train_fine.ipynb +919 -0
  16. train_semantic.ipynb +899 -0
.gitignore ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ __pycache__/
2
+ *.wav
3
+ _temp/
4
+ models/
5
+ wandb/
6
+ *_output/
7
+ output.npz
8
+ joe_biden_state_of_union/
9
+ Retrieval-based-Voice-Conversion-WebUI/
10
+ devin-youtube/
11
+ train_rvc.ipynb
12
+ *.pt
FUNDING.yml ADDED
@@ -0,0 +1 @@
 
 
1
+ github: serp-ai
LICENSE.md ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ MIT License
2
+
3
+ Copyright (c) 2023 [SERP](https://serp.co/) | [SERP AI](https://serp.ai/) | [DS](https://devinschumacher.com/)
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+
23
+ # A humble request
24
+
25
+ Our mission is to make artificial intelligence accessible & enjoyable, so we can all build bridges to the future, together.
26
+
27
+ Please, feel free to use this as you see fit in accordance with the law & ideally inline with our values of accessibility, equality & AI for all.
28
+
29
+ We only have one humble request (not requirement) ... that you represent these values by adding one of our (extremely awesome) AI badges on your website / github / etc.
30
+
31
+ 👉 You can generate & customize your own here: [https://serp.ly/@serpai/badges/ai](https://serp.ly/@serpai/badges/ai)
32
+
33
+ Thank you!
README.md CHANGED
@@ -1,3 +1,190 @@
1
- ---
2
- license: apache-2.0
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # 🐶 BARK AI: but with the ability to use voice cloning on custom audio samples
2
+
3
+ For RVC `git clone https://github.com/RVC-Project/Retrieval-based-Voice-Conversion-WebUI` and train your model or point the code to you model (must clone RVC repo in bark-with-voice-clone directory)
4
+
5
+ If you want to clone a voice just follow the `clone_voice.ipynb` notebook. If you want to generate audio from text, follow the `generate.ipynb` notebook.
6
+
7
+ To create a voice clone sample, you need an audio sample of around 5-12 seconds
8
+
9
+ You will get the best results by making generations with your cloned voice until you find one that is really close to the source. Then use that as the new history prompt (comes from the model so should theoretically be more consistent)
10
+
11
+ - [BARK text to speech @ SERP AI](https://serp.ai/tools/bark-text-to-speech-ai-voice-clone-app/)
12
+
13
+ # Contributors
14
+
15
+ Huge shoutout & thank you to:
16
+
17
+ [gitmylo](https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer/)
18
+ for the solution to the semantic token generation for better voice clones and finetunes (HuBERT, etc.)
19
+
20
+ ***
21
+
22
+ <div style="display: flex; flex-wrap: wrap;">
23
+ <a href="https://github.com/francislabountyjr" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/73464335?v=4" alt="francislabountyjr" style="border-radius: 50%; width: 75px; height: 75px;"></a>
24
+ <a href="https://github.com/gkucsko" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/5068315?v=4" alt="gkucsko" style="border-radius: 50%; width: 75px; height: 75px;"></a>
25
+ <a href="https://github.com/kmfreyberg" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/32879321?v=4" alt="kmfreyberg" style="border-radius: 50%; width: 75px; height: 75px;"></a>
26
+ <a href="https://github.com/Vaibhavs10" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/18682411?v=4" alt="Vaibhavs10" style="border-radius: 50%; width: 75px; height: 75px;"></a>
27
+ <a href="https://github.com/devinschumacher" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/45643901?v=4" alt="devinschumacher" style="border-radius: 50%; width: 75px; height: 75px;"></a>
28
+ <a href="https://github.com/mcamac" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/461009?v=4" alt="mcamac" style="border-radius: 50%; width: 75px; height: 75px;"></a>
29
+ <a href="https://github.com/fiq" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/236293?v=4" alt="fiq" style="border-radius: 50%; width: 75px; height: 75px;"></a>
30
+ <a href="https://github.com/zygi" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/2059901?v=4" alt="zygi" style="border-radius: 50%; width: 75px; height: 75px;"></a>
31
+ <a href="https://github.com/jn-jairo" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/5104869?v=4" alt="jn-jairo" style="border-radius: 50%; width: 75px; height: 75px;"></a>
32
+ <a href="https://github.com/gitmylo" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/36931363?v=4" alt="gitmylo" style="border-radius: 50%; width: 75px; height: 75px;"></a>
33
+ <a href="https://github.com/alyxdow" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/84633629?v=4" alt="alyxdow" style="border-radius: 50%; width: 75px; height: 75px;"></a>
34
+ <a href="https://github.com/mikeyshulman" target="_blank" style="margin: 5px; display: inline-block;"><img src="https://avatars.githubusercontent.com/u/2565833?v=4" alt="mikeyshulman" style="border-radius: 50%; width: 75px; height: 75px;"></a>
35
+ </div>
36
+
37
+
38
+
39
+
40
+
41
+ -------------------------------------------------------------------
42
+ # Original README.md
43
+ ## 🤖 Usage
44
+
45
+ ```python
46
+ from bark import SAMPLE_RATE, generate_audio, preload_models
47
+ from IPython.display import Audio
48
+
49
+ # download and load all models
50
+ preload_models()
51
+
52
+ # generate audio from text
53
+ text_prompt = """
54
+ Hello, my name is Serpy. And, uh — and I like pizza. [laughs]
55
+ But I also have other interests such as playing tic tac toe.
56
+ """
57
+ audio_array = generate_audio(text_prompt)
58
+
59
+ # play text in notebook
60
+ Audio(audio_array, rate=SAMPLE_RATE)
61
+ ```
62
+
63
+ [pizza.webm](https://user-images.githubusercontent.com/5068315/230490503-417e688d-5115-4eee-9550-b46a2b465ee3.webm)
64
+
65
+
66
+ To save `audio_array` as a WAV file:
67
+
68
+ ```python
69
+ from scipy.io.wavfile import write as write_wav
70
+
71
+ write_wav("/path/to/audio.wav", SAMPLE_RATE, audio_array)
72
+ ```
73
+
74
+ ### 🌎 Foreign Language
75
+
76
+ Bark supports various languages out-of-the-box and automatically determines language from input text. When prompted with code-switched text, Bark will attempt to employ the native accent for the respective languages. English quality is best for the time being, and we expect other languages to further improve with scaling.
77
+
78
+ ```python
79
+ text_prompt = """
80
+ Buenos días Miguel. Tu colega piensa que tu alemán es extremadamente malo.
81
+ But I suppose your english isn't terrible.
82
+ """
83
+ audio_array = generate_audio(text_prompt)
84
+ ```
85
+
86
+ [miguel.webm](https://user-images.githubusercontent.com/5068315/230684752-10baadfe-1e7c-46a2-8323-43282aef2c8c.webm)
87
+
88
+ ### 🎶 Music
89
+
90
+ Bark can generate all types of audio, and, in principle, doesn't see a difference between speech and music. Sometimes Bark chooses to generate text as music, but you can help it out by adding music notes around your lyrics.
91
+
92
+ ```python
93
+ text_prompt = """
94
+ ♪ In the jungle, the mighty jungle, the lion barks tonight ♪
95
+ """
96
+ audio_array = generate_audio(text_prompt)
97
+ ```
98
+
99
+ [lion.webm](https://user-images.githubusercontent.com/5068315/230684766-97f5ea23-ad99-473c-924b-66b6fab24289.webm)
100
+
101
+ ### 🎤 Voice Presets and Voice/Audio Cloning
102
+
103
+ Bark has the capability to fully clone voices - including tone, pitch, emotion and prosody. The model also attempts to preserve music, ambient noise, etc. from input audio. However, to mitigate misuse of this technology, we limit the audio history prompts to a limited set of Suno-provided, fully synthetic options to choose from for each language. Specify following the pattern: `{lang_code}_speaker_{0-9}`.
104
+
105
+ ```python
106
+ text_prompt = """
107
+ I have a silky smooth voice, and today I will tell you about
108
+ the exercise regimen of the common sloth.
109
+ """
110
+ audio_array = generate_audio(text_prompt, history_prompt="en_speaker_1")
111
+ ```
112
+
113
+
114
+ [sloth.webm](https://user-images.githubusercontent.com/5068315/230684883-a344c619-a560-4ff5-8b99-b4463a34487b.webm)
115
+
116
+ *Note: since Bark recognizes languages automatically from input text, it is possible to use for example a german history prompt with english text. This usually leads to english audio with a german accent.*
117
+
118
+ ### 👥 Speaker Prompts
119
+
120
+ You can provide certain speaker prompts such as NARRATOR, MAN, WOMAN, etc. Please note that these are not always respected, especially if a conflicting audio history prompt is given.
121
+
122
+ ```python
123
+ text_prompt = """
124
+ WOMAN: I would like an oatmilk latte please.
125
+ MAN: Wow, that's expensive!
126
+ """
127
+ audio_array = generate_audio(text_prompt)
128
+ ```
129
+
130
+ [latte.webm](https://user-images.githubusercontent.com/5068315/230684864-12d101a1-a726-471d-9d56-d18b108efcb8.webm)
131
+
132
+
133
+ ## 💻 Installation
134
+
135
+ ```
136
+ pip install git+https://github.com/suno-ai/bark.git
137
+ ```
138
+
139
+ or
140
+
141
+ ```
142
+ git clone https://github.com/suno-ai/bark
143
+ cd bark && pip install .
144
+ ```
145
+
146
+ ## 🛠️ Hardware and Inference Speed
147
+
148
+ Bark has been tested and works on both CPU and GPU (`pytorch 2.0+`, CUDA 11.7 and CUDA 12.0).
149
+ Running Bark requires running >100M parameter transformer models.
150
+ On modern GPUs and PyTorch nightly, Bark can generate audio in roughly realtime. On older GPUs, default colab, or CPU, inference time might be 10-100x slower.
151
+
152
+ ## ⚙️ Details
153
+
154
+ Similar to [Vall-E](https://arxiv.org/abs/2301.02111) and some other amazing work in the field, Bark uses GPT-style
155
+ models to generate audio from scratch. Different from Vall-E, the initial text prompt is embedded into high-level semantic tokens without the use of phonemes. It can therefore generalize to arbitrary instructions beyond speech that occur in the training data, such as music lyrics, sound effects or other non-speech sounds. A subsequent second model is used to convert the generated semantic tokens into audio codec tokens to generate the full waveform. To enable the community to use Bark via public code we used the fantastic
156
+ [EnCodec codec](https://github.com/facebookresearch/encodec) from Facebook to act as an audio representation.
157
+
158
+ Below is a list of some known non-speech sounds
159
+
160
+ - `[laughter]`
161
+ - `[laughs]`
162
+ - `[sighs]`
163
+ - `[music]`
164
+ - `[gasps]`
165
+ - `[clears throat]`
166
+ - `—` or `...` for hesitations
167
+ - `♪` for song lyrics
168
+ - capitalization for emphasis of a word
169
+ - `MAN/WOMAN:` for bias towards speaker
170
+
171
+ **Supported Languages**
172
+
173
+ | Language | Status |
174
+ | --- | --- |
175
+ | English (en) | ✅ |
176
+ | German (de) | ✅ |
177
+ | Spanish (es) | ✅ |
178
+ | French (fr) | ✅ |
179
+ | Hindi (hi) | ✅ |
180
+ | Italian (it) | ✅ |
181
+ | Japanese (ja) | ✅ |
182
+ | Korean (ko) | ✅ |
183
+ | Polish (pl) | ✅ |
184
+ | Portuguese (pt) | ✅ |
185
+ | Russian (ru) | ✅ |
186
+ | Turkish (tr) | ✅ |
187
+ | Chinese, simplified (zh) | ✅ |
188
+ | Arabic | Coming soon! |
189
+ | Bengali | Coming soon! |
190
+ | Telugu | Coming soon! |
clone_voice.ipynb ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from bark.generation import load_codec_model, generate_text_semantic\n",
10
+ "from encodec.utils import convert_audio\n",
11
+ "\n",
12
+ "import torchaudio\n",
13
+ "import torch\n",
14
+ "\n",
15
+ "device = 'cuda' # or 'cpu'\n",
16
+ "model = load_codec_model(use_gpu=True if device == 'cuda' else False)"
17
+ ]
18
+ },
19
+ {
20
+ "cell_type": "code",
21
+ "execution_count": null,
22
+ "metadata": {},
23
+ "outputs": [],
24
+ "source": [
25
+ "# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n",
26
+ "from hubert.hubert_manager import HuBERTManager\n",
27
+ "hubert_manager = HuBERTManager()\n",
28
+ "hubert_manager.make_sure_hubert_installed()\n",
29
+ "hubert_manager.make_sure_tokenizer_installed()"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer \n",
39
+ "# Load HuBERT for semantic tokens\n",
40
+ "from hubert.pre_kmeans_hubert import CustomHubert\n",
41
+ "from hubert.customtokenizer import CustomTokenizer\n",
42
+ "\n",
43
+ "# Load the HuBERT model\n",
44
+ "hubert_model = CustomHubert(checkpoint_path='data/models/hubert/hubert.pt').to(device)\n",
45
+ "\n",
46
+ "# Load the CustomTokenizer model\n",
47
+ "tokenizer = CustomTokenizer.load_from_checkpoint('data/models/hubert/tokenizer.pth').to(device) # Automatically uses the right layers"
48
+ ]
49
+ },
50
+ {
51
+ "cell_type": "code",
52
+ "execution_count": null,
53
+ "metadata": {},
54
+ "outputs": [],
55
+ "source": [
56
+ "# Load and pre-process the audio waveform\n",
57
+ "audio_filepath = 'audio.wav' # the audio you want to clone (under 13 seconds)\n",
58
+ "wav, sr = torchaudio.load(audio_filepath)\n",
59
+ "wav = convert_audio(wav, sr, model.sample_rate, model.channels)\n",
60
+ "wav = wav.to(device)"
61
+ ]
62
+ },
63
+ {
64
+ "cell_type": "code",
65
+ "execution_count": null,
66
+ "metadata": {},
67
+ "outputs": [],
68
+ "source": [
69
+ "semantic_vectors = hubert_model.forward(wav, input_sample_hz=model.sample_rate)\n",
70
+ "semantic_tokens = tokenizer.get_token(semantic_vectors)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "code",
75
+ "execution_count": null,
76
+ "metadata": {},
77
+ "outputs": [],
78
+ "source": [
79
+ "# Extract discrete codes from EnCodec\n",
80
+ "with torch.no_grad():\n",
81
+ " encoded_frames = model.encode(wav.unsqueeze(0))\n",
82
+ "codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "# move codes to cpu\n",
92
+ "codes = codes.cpu().numpy()\n",
93
+ "# move semantic tokens to cpu\n",
94
+ "semantic_tokens = semantic_tokens.cpu().numpy()"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "metadata": {},
101
+ "outputs": [],
102
+ "source": [
103
+ "import numpy as np\n",
104
+ "voice_name = 'output' # whatever you want the name of the voice to be\n",
105
+ "output_path = 'bark/assets/prompts/' + voice_name + '.npz'\n",
106
+ "np.savez(output_path, fine_prompt=codes, coarse_prompt=codes[:2, :], semantic_prompt=semantic_tokens)"
107
+ ]
108
+ },
109
+ {
110
+ "cell_type": "code",
111
+ "execution_count": null,
112
+ "metadata": {},
113
+ "outputs": [],
114
+ "source": [
115
+ "# That's it! Now you can head over to the generate.ipynb and use your voice_name for the 'history_prompt'"
116
+ ]
117
+ },
118
+ {
119
+ "cell_type": "code",
120
+ "execution_count": null,
121
+ "metadata": {},
122
+ "outputs": [],
123
+ "source": []
124
+ },
125
+ {
126
+ "cell_type": "code",
127
+ "execution_count": null,
128
+ "metadata": {},
129
+ "outputs": [],
130
+ "source": [
131
+ "# Heres the generation stuff copy-pasted for convenience"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": null,
137
+ "metadata": {},
138
+ "outputs": [],
139
+ "source": [
140
+ "from bark.api import generate_audio\n",
141
+ "from transformers import BertTokenizer\n",
142
+ "from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic\n",
143
+ "\n",
144
+ "# Enter your prompt and speaker here\n",
145
+ "text_prompt = \"Hello, my name is Serpy. And, uh — and I like pizza. [laughs]\"\n",
146
+ "voice_name = \"output\" # use your custom voice name here if you have one"
147
+ ]
148
+ },
149
+ {
150
+ "cell_type": "code",
151
+ "execution_count": null,
152
+ "metadata": {},
153
+ "outputs": [],
154
+ "source": [
155
+ "# download and load all models\n",
156
+ "preload_models(\n",
157
+ " text_use_gpu=True,\n",
158
+ " text_use_small=False,\n",
159
+ " coarse_use_gpu=True,\n",
160
+ " coarse_use_small=False,\n",
161
+ " fine_use_gpu=True,\n",
162
+ " fine_use_small=False,\n",
163
+ " codec_use_gpu=True,\n",
164
+ " force_reload=False,\n",
165
+ " path=\"models\"\n",
166
+ ")"
167
+ ]
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "execution_count": null,
172
+ "metadata": {},
173
+ "outputs": [],
174
+ "source": [
175
+ "# simple generation\n",
176
+ "audio_array = generate_audio(text_prompt, history_prompt=voice_name, text_temp=0.7, waveform_temp=0.7)"
177
+ ]
178
+ },
179
+ {
180
+ "cell_type": "code",
181
+ "execution_count": null,
182
+ "metadata": {},
183
+ "outputs": [],
184
+ "source": [
185
+ "# generation with more control\n",
186
+ "x_semantic = generate_text_semantic(\n",
187
+ " text_prompt,\n",
188
+ " history_prompt=voice_name,\n",
189
+ " temp=0.7,\n",
190
+ " top_k=50,\n",
191
+ " top_p=0.95,\n",
192
+ ")\n",
193
+ "\n",
194
+ "x_coarse_gen = generate_coarse(\n",
195
+ " x_semantic,\n",
196
+ " history_prompt=voice_name,\n",
197
+ " temp=0.7,\n",
198
+ " top_k=50,\n",
199
+ " top_p=0.95,\n",
200
+ ")\n",
201
+ "x_fine_gen = generate_fine(\n",
202
+ " x_coarse_gen,\n",
203
+ " history_prompt=voice_name,\n",
204
+ " temp=0.5,\n",
205
+ ")\n",
206
+ "audio_array = codec_decode(x_fine_gen)"
207
+ ]
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "execution_count": null,
212
+ "metadata": {},
213
+ "outputs": [],
214
+ "source": [
215
+ "from IPython.display import Audio\n",
216
+ "# play audio\n",
217
+ "Audio(audio_array, rate=SAMPLE_RATE)"
218
+ ]
219
+ },
220
+ {
221
+ "cell_type": "code",
222
+ "execution_count": null,
223
+ "metadata": {},
224
+ "outputs": [],
225
+ "source": [
226
+ "from scipy.io.wavfile import write as write_wav\n",
227
+ "# save audio\n",
228
+ "filepath = \"/output/audio.wav\" # change this to your desired output path\n",
229
+ "write_wav(filepath, SAMPLE_RATE, audio_array)"
230
+ ]
231
+ }
232
+ ],
233
+ "metadata": {
234
+ "kernelspec": {
235
+ "display_name": "Python 3",
236
+ "language": "python",
237
+ "name": "python3"
238
+ },
239
+ "language_info": {
240
+ "codemirror_mode": {
241
+ "name": "ipython",
242
+ "version": 3
243
+ },
244
+ "file_extension": ".py",
245
+ "mimetype": "text/x-python",
246
+ "name": "python",
247
+ "nbconvert_exporter": "python",
248
+ "pygments_lexer": "ipython3",
249
+ "version": "3.10.8"
250
+ },
251
+ "orig_nbformat": 4
252
+ },
253
+ "nbformat": 4,
254
+ "nbformat_minor": 2
255
+ }
generate.ipynb ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from IPython.display import Audio\n",
10
+ "from scipy.io.wavfile import write as write_wav\n",
11
+ "\n",
12
+ "from bark.api import generate_audio\n",
13
+ "from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "semantic_path = \"semantic_output/pytorch_model.bin\" # set to None if you don't want to use finetuned semantic\n",
23
+ "coarse_path = \"coarse_output/pytorch_model.bin\" # set to None if you don't want to use finetuned coarse\n",
24
+ "fine_path = \"fine_output/pytorch_model.bin\" # set to None if you don't want to use finetuned fine\n",
25
+ "use_rvc = True # Set to False to use bark without RVC\n",
26
+ "rvc_name = 'mi-test'\n",
27
+ "rvc_path = f\"Retrieval-based-Voice-Conversion-WebUI/weights/{rvc_name}.pth\"\n",
28
+ "index_path = f\"Retrieval-based-Voice-Conversion-WebUI/logs/{rvc_name}/added_IVF256_Flat_nprobe_1_{rvc_name}_v2.index\"\n",
29
+ "device=\"cuda:0\"\n",
30
+ "is_half=True"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "# download and load all models\n",
40
+ "preload_models(\n",
41
+ " text_use_gpu=True,\n",
42
+ " text_use_small=False,\n",
43
+ " text_model_path=semantic_path,\n",
44
+ " coarse_use_gpu=True,\n",
45
+ " coarse_use_small=False,\n",
46
+ " coarse_model_path=coarse_path,\n",
47
+ " fine_use_gpu=True,\n",
48
+ " fine_use_small=False,\n",
49
+ " fine_model_path=fine_path,\n",
50
+ " codec_use_gpu=True,\n",
51
+ " force_reload=False,\n",
52
+ " path=\"models\"\n",
53
+ ")\n",
54
+ "\n",
55
+ "if use_rvc:\n",
56
+ " from rvc_infer import get_vc, vc_single\n",
57
+ " get_vc(rvc_path, device, is_half)"
58
+ ]
59
+ },
60
+ {
61
+ "cell_type": "code",
62
+ "execution_count": null,
63
+ "metadata": {},
64
+ "outputs": [],
65
+ "source": [
66
+ "# simple generation\n",
67
+ "text_prompt = \"Hello, my name is Serpy. And, uh — and I like pizza. [laughs]\"\n",
68
+ "voice_name = \"speaker_0\" # use your custom voice name here if you have on\n",
69
+ "\n",
70
+ "filepath = \"output/audio.wav\"\n",
71
+ "audio_array = generate_audio(text_prompt, history_prompt=voice_name, text_temp=0.7, waveform_temp=0.7)\n",
72
+ "write_wav(filepath, SAMPLE_RATE, audio_array)\n",
73
+ "\n",
74
+ "if use_rvc:\n",
75
+ " index_rate = 0.75\n",
76
+ " f0up_key = -6\n",
77
+ " filter_radius = 3\n",
78
+ " rms_mix_rate = 0.25\n",
79
+ " protect = 0.33\n",
80
+ " resample_sr = SAMPLE_RATE\n",
81
+ " f0method = \"harvest\" #harvest or pm\n",
82
+ " try:\n",
83
+ " audio_array = vc_single(0,filepath,f0up_key,None,f0method,index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
84
+ " except:\n",
85
+ " audio_array = vc_single(0,filepath,f0up_key,None,'pm',index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
86
+ " write_wav(filepath, SAMPLE_RATE, audio_array)\n",
87
+ "\n",
88
+ "Audio(audio_array, rate=SAMPLE_RATE)"
89
+ ]
90
+ },
91
+ {
92
+ "cell_type": "code",
93
+ "execution_count": null,
94
+ "metadata": {},
95
+ "outputs": [],
96
+ "source": [
97
+ "# generation with more control\n",
98
+ "text_prompt = \"Hello, my name is Serpy. And, uh — and I like pizza. [laughs]\"\n",
99
+ "voice_name = \"speaker_0\" # use your custom voice name here if you have on\n",
100
+ "\n",
101
+ "filepath = \"output/audio.wav\"\n",
102
+ "\n",
103
+ "x_semantic = generate_text_semantic(\n",
104
+ " text_prompt,\n",
105
+ " history_prompt=voice_name,\n",
106
+ " temp=0.7,\n",
107
+ " top_k=50,\n",
108
+ " top_p=0.95,\n",
109
+ ")\n",
110
+ "\n",
111
+ "x_coarse_gen = generate_coarse(\n",
112
+ " x_semantic,\n",
113
+ " history_prompt=voice_name,\n",
114
+ " temp=0.7,\n",
115
+ " top_k=50,\n",
116
+ " top_p=0.95,\n",
117
+ ")\n",
118
+ "x_fine_gen = generate_fine(\n",
119
+ " x_coarse_gen,\n",
120
+ " history_prompt=voice_name,\n",
121
+ " temp=0.5,\n",
122
+ ")\n",
123
+ "audio_array = codec_decode(x_fine_gen)\n",
124
+ "write_wav(filepath, SAMPLE_RATE, audio_array)\n",
125
+ "\n",
126
+ "if use_rvc:\n",
127
+ " index_rate = 0.75\n",
128
+ " f0up_key = -6\n",
129
+ " filter_radius = 3\n",
130
+ " rms_mix_rate = 0.25\n",
131
+ " protect = 0.33\n",
132
+ " resample_sr = SAMPLE_RATE\n",
133
+ " f0method = \"harvest\" #harvest or pm\n",
134
+ " try:\n",
135
+ " audio_array = vc_single(0,filepath,f0up_key,None,f0method,index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
136
+ " except:\n",
137
+ " audio_array = vc_single(0,filepath,f0up_key,None,'pm',index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
138
+ " write_wav(filepath, SAMPLE_RATE, audio_array)\n",
139
+ "\n",
140
+ "Audio(audio_array, rate=SAMPLE_RATE)"
141
+ ]
142
+ },
143
+ {
144
+ "cell_type": "code",
145
+ "execution_count": null,
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": []
149
+ }
150
+ ],
151
+ "metadata": {
152
+ "kernelspec": {
153
+ "display_name": "Python 3",
154
+ "language": "python",
155
+ "name": "python3"
156
+ },
157
+ "language_info": {
158
+ "codemirror_mode": {
159
+ "name": "ipython",
160
+ "version": 3
161
+ },
162
+ "file_extension": ".py",
163
+ "mimetype": "text/x-python",
164
+ "name": "python",
165
+ "nbconvert_exporter": "python",
166
+ "pygments_lexer": "ipython3",
167
+ "version": "3.10.8"
168
+ },
169
+ "orig_nbformat": 4
170
+ },
171
+ "nbformat": 4,
172
+ "nbformat_minor": 2
173
+ }
generate_chunked.ipynb ADDED
@@ -0,0 +1,344 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from IPython.display import Audio\n",
10
+ "from scipy.io.wavfile import write as write_wav\n",
11
+ "\n",
12
+ "from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic"
13
+ ]
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "execution_count": null,
18
+ "metadata": {},
19
+ "outputs": [],
20
+ "source": [
21
+ "semantic_path = \"semantic_output/pytorch_model.bin\" # set to None if you don't want to use finetuned semantic\n",
22
+ "coarse_path = \"coarse_output/pytorch_model.bin\" # set to None if you don't want to use finetuned coarse\n",
23
+ "fine_path = \"fine_output/pytorch_model.bin\" # set to None if you don't want to use finetuned fine\n",
24
+ "use_rvc = True # Set to False to use bark without RVC\n",
25
+ "rvc_name = 'mi-test'\n",
26
+ "rvc_path = f\"Retrieval-based-Voice-Conversion-WebUI/weights/{rvc_name}.pth\"\n",
27
+ "index_path = f\"Retrieval-based-Voice-Conversion-WebUI/logs/{rvc_name}/added_IVF256_Flat_nprobe_1_{rvc_name}_v2.index\" \n",
28
+ "device=\"cuda:0\"\n",
29
+ "is_half=True"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "import re\n",
39
+ "def split_and_recombine_text(text, desired_length=100, max_length=150):\n",
40
+ " # from https://github.com/neonbjb/tortoise-tts\n",
41
+ " \"\"\"Split text it into chunks of a desired length trying to keep sentences intact.\"\"\"\n",
42
+ " # normalize text, remove redundant whitespace and convert non-ascii quotes to ascii\n",
43
+ " text = re.sub(r\"\\n\\n+\", \"\\n\", text)\n",
44
+ " text = re.sub(r\"\\s+\", \" \", text)\n",
45
+ " text = re.sub(r\"[“”]\", '\"', text)\n",
46
+ "\n",
47
+ " rv = []\n",
48
+ " in_quote = False\n",
49
+ " current = \"\"\n",
50
+ " split_pos = []\n",
51
+ " pos = -1\n",
52
+ " end_pos = len(text) - 1\n",
53
+ "\n",
54
+ " def seek(delta):\n",
55
+ " nonlocal pos, in_quote, current\n",
56
+ " is_neg = delta < 0\n",
57
+ " for _ in range(abs(delta)):\n",
58
+ " if is_neg:\n",
59
+ " pos -= 1\n",
60
+ " current = current[:-1]\n",
61
+ " else:\n",
62
+ " pos += 1\n",
63
+ " current += text[pos]\n",
64
+ " if text[pos] == '\"':\n",
65
+ " in_quote = not in_quote\n",
66
+ " return text[pos]\n",
67
+ "\n",
68
+ " def peek(delta):\n",
69
+ " p = pos + delta\n",
70
+ " return text[p] if p < end_pos and p >= 0 else \"\"\n",
71
+ "\n",
72
+ " def commit():\n",
73
+ " nonlocal rv, current, split_pos\n",
74
+ " rv.append(current)\n",
75
+ " current = \"\"\n",
76
+ " split_pos = []\n",
77
+ "\n",
78
+ " while pos < end_pos:\n",
79
+ " c = seek(1)\n",
80
+ " # do we need to force a split?\n",
81
+ " if len(current) >= max_length:\n",
82
+ " if len(split_pos) > 0 and len(current) > (desired_length / 2):\n",
83
+ " # we have at least one sentence and we are over half the desired length, seek back to the last split\n",
84
+ " d = pos - split_pos[-1]\n",
85
+ " seek(-d)\n",
86
+ " else:\n",
87
+ " # no full sentences, seek back until we are not in the middle of a word and split there\n",
88
+ " while c not in \"!?.\\n \" and pos > 0 and len(current) > desired_length:\n",
89
+ " c = seek(-1)\n",
90
+ " commit()\n",
91
+ " # check for sentence boundaries\n",
92
+ " elif not in_quote and (c in \"!?\\n\" or (c == \".\" and peek(1) in \"\\n \")):\n",
93
+ " # seek forward if we have consecutive boundary markers but still within the max length\n",
94
+ " while (\n",
95
+ " pos < len(text) - 1 and len(current) < max_length and peek(1) in \"!?.\"\n",
96
+ " ):\n",
97
+ " c = seek(1)\n",
98
+ " split_pos.append(pos)\n",
99
+ " if len(current) >= desired_length:\n",
100
+ " commit()\n",
101
+ " # treat end of quote as a boundary if its followed by a space or newline\n",
102
+ " elif in_quote and peek(1) == '\"' and peek(2) in \"\\n \":\n",
103
+ " seek(2)\n",
104
+ " split_pos.append(pos)\n",
105
+ " rv.append(current)\n",
106
+ "\n",
107
+ " # clean up, remove lines with only whitespace or punctuation\n",
108
+ " rv = [s.strip() for s in rv]\n",
109
+ " rv = [s for s in rv if len(s) > 0 and not re.match(r\"^[\\s\\.,;:!?]*$\", s)]\n",
110
+ "\n",
111
+ " return rv\n",
112
+ "\n",
113
+ "def generate_with_settings(text_prompt, semantic_temp=0.7, semantic_top_k=50, semantic_top_p=0.95, coarse_temp=0.7, coarse_top_k=50, coarse_top_p=0.95, fine_temp=0.5, voice_name=None, use_semantic_history_prompt=True, use_coarse_history_prompt=True, use_fine_history_prompt=True, output_full=False):\n",
114
+ " # generation with more control\n",
115
+ " x_semantic = generate_text_semantic(\n",
116
+ " text_prompt,\n",
117
+ " history_prompt=voice_name if use_semantic_history_prompt else None,\n",
118
+ " temp=semantic_temp,\n",
119
+ " top_k=semantic_top_k,\n",
120
+ " top_p=semantic_top_p,\n",
121
+ " )\n",
122
+ "\n",
123
+ " x_coarse_gen = generate_coarse(\n",
124
+ " x_semantic,\n",
125
+ " history_prompt=voice_name if use_coarse_history_prompt else None,\n",
126
+ " temp=coarse_temp,\n",
127
+ " top_k=coarse_top_k,\n",
128
+ " top_p=coarse_top_p,\n",
129
+ " )\n",
130
+ " x_fine_gen = generate_fine(\n",
131
+ " x_coarse_gen,\n",
132
+ " history_prompt=voice_name if use_fine_history_prompt else None,\n",
133
+ " temp=fine_temp,\n",
134
+ " )\n",
135
+ "\n",
136
+ " if output_full:\n",
137
+ " full_generation = {\n",
138
+ " 'semantic_prompt': x_semantic,\n",
139
+ " 'coarse_prompt': x_coarse_gen,\n",
140
+ " 'fine_prompt': x_fine_gen,\n",
141
+ " }\n",
142
+ " return full_generation, codec_decode(x_fine_gen)\n",
143
+ " return codec_decode(x_fine_gen)"
144
+ ]
145
+ },
146
+ {
147
+ "cell_type": "code",
148
+ "execution_count": null,
149
+ "metadata": {},
150
+ "outputs": [],
151
+ "source": [
152
+ "# `[laughter]`\n",
153
+ "# - `[laughs]`\n",
154
+ "# - `[sighs]`\n",
155
+ "# - `[music]`\n",
156
+ "# - `[gasps]`\n",
157
+ "# - `[clears throat]`\n",
158
+ "# - `—` or `...` for hesitations\n",
159
+ "# - `♪` for song lyrics"
160
+ ]
161
+ },
162
+ {
163
+ "cell_type": "code",
164
+ "execution_count": null,
165
+ "metadata": {},
166
+ "outputs": [],
167
+ "source": [
168
+ "# download and load all models\n",
169
+ "preload_models(\n",
170
+ " text_use_gpu=True,\n",
171
+ " text_use_small=False,\n",
172
+ " text_model_path=semantic_path,\n",
173
+ " coarse_use_gpu=True,\n",
174
+ " coarse_use_small=False,\n",
175
+ " coarse_model_path=coarse_path,\n",
176
+ " fine_use_gpu=True,\n",
177
+ " fine_use_small=False,\n",
178
+ " fine_model_path=fine_path,\n",
179
+ " codec_use_gpu=True,\n",
180
+ " force_reload=False,\n",
181
+ " path=\"models\"\n",
182
+ ")\n",
183
+ "\n",
184
+ "if use_rvc:\n",
185
+ " from rvc_infer import get_vc, vc_single\n",
186
+ " get_vc(rvc_path, device, is_half)"
187
+ ]
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "execution_count": null,
192
+ "metadata": {},
193
+ "outputs": [],
194
+ "source": [
195
+ "text = \"\"\"The Uncharted Land of Discovery: A Journey Through Time and Space\n",
196
+ "[clears throat]\n",
197
+ "Chapter 1: The Dawn of Curiosity\n",
198
+ "[takes breath]\n",
199
+ "Since the dawn of humankind, our species has been driven by a powerful force: curiosity. It is an innate, unquenchable desire to explore, understand, and unravel the mysteries of the world around us. This primal urge has led us on countless adventures, pushing us to the farthest reaches of our planet and beyond.\n",
200
+ "\n",
201
+ "Early humans, huddled around a flickering fire, gazed up at the night sky and wondered what those twinkling lights were. They had no idea that their curiosity would eventually propel us into the vast, uncharted realm of space. As time progressed, our ancestors began to explore their surroundings, venturing beyond their caves and settlements, driven by the need to discover what lay beyond the horizon.\n",
202
+ "\n",
203
+ "hapter 2: The Age of Exploration\n",
204
+ "\n",
205
+ "The Age of Exploration marked a turning point in human history, as brave souls took to the seas in search of new lands, wealth, and knowledge. Pioneers like Christopher Columbus, Vasco da Gama, and Ferdinand Magellan set sail on perilous voyages, pushing the boundaries of what was known and understood.\n",
206
+ "[clears throat]\n",
207
+ "These intrepid explorers discovered new continents, mapped out previously unknown territories, and encountered diverse cultures. They also established trade routes, allowing for the exchange of goods, ideas, and innovations between distant societies. The Age of Exploration was not without its dark moments, however, as conquest, colonization, and exploitation often went hand in hand with discovery.\n",
208
+ "[clears throat]\n",
209
+ "Chapter 3: The Scientific Revolution\n",
210
+ "[laughs]\n",
211
+ "The Scientific Revolution was a period of profound change, as humanity began to question long-held beliefs and seek empirical evidence. Pioneers like Galileo Galilei, Isaac Newton, and Johannes Kepler sought to understand the natural world through observation, experimentation, and reason.\n",
212
+ "[sighs]\n",
213
+ "Their discoveries laid the foundation for modern science, transforming the way we view the universe and our place within it. New technologies, such as the telescope and the microscope, allowed us to peer deeper into the cosmos and the microscopic world, further expanding our understanding of reality.\n",
214
+ "[gasps]\n",
215
+ "Chapter 4: The Information Age\n",
216
+ "\n",
217
+ "The Information Age, sometimes referred to as the Digital Age, has revolutionized the way we communicate, learn, and access knowledge. With the advent of the internet and personal computers, information that was once reserved for the privileged few is now available to the masses.\n",
218
+ "...\n",
219
+ "This democratization of knowledge has led to an explosion of innovation, as ideas and information are shared across borders and cultures at lightning speed. The Information Age has also brought new challenges, as the rapid pace of technological advancements threatens to outpace our ability to adapt and raises questions about the ethical implications of our increasingly interconnected world.\n",
220
+ "[laughter]\n",
221
+ "Chapter 5: The Final Frontier\n",
222
+ "[clears throat]\n",
223
+ "As our knowledge of the universe expands, so too does our desire to explore the cosmos. Space exploration has come a long way since the first successful satellite, Sputnik, was launched in 1957. We have landed humans on the moon, sent probes to the far reaches of our solar system, and even glimpsed distant galaxies through powerful telescopes.\n",
224
+ "\n",
225
+ "The future of space exploration is filled with possibilities, from establishing colonies on Mars to the search for extraterrestrial life. As we venture further into the unknown, we continue to be driven by the same curiosity that has propelled us throughout history, always seeking to uncover the secrets of the universe and our place within it.\n",
226
+ "...\n",
227
+ "In conclusion, the human journey is one of discovery, driven by our innate curiosity and desire to understand the world around us. From the dawn of our species to the present day, we have continued to explore, learn, and adapt, pushing the boundaries of what is known and possible. As we continue to unravel the mysteries of the cosmos, our spirit.\"\"\""
228
+ ]
229
+ },
230
+ {
231
+ "cell_type": "code",
232
+ "execution_count": null,
233
+ "metadata": {},
234
+ "outputs": [],
235
+ "source": [
236
+ "# Chunk the text into smaller pieces then combine the generated audio\n",
237
+ "from time import time\n",
238
+ "from tqdm.auto import tqdm\n",
239
+ "from IPython.display import Audio\n",
240
+ "from scipy.io.wavfile import write as write_wav\n",
241
+ "import os\n",
242
+ "import numpy as np\n",
243
+ "\n",
244
+ "# generation settings\n",
245
+ "voice_name = 'en_speaker_0'\n",
246
+ "out_filepath = 'audio/audio.wav'\n",
247
+ "\n",
248
+ "semantic_temp = 0.7\n",
249
+ "semantic_top_k = 50\n",
250
+ "semantic_top_p = 0.95\n",
251
+ "\n",
252
+ "coarse_temp = 0.7\n",
253
+ "coarse_top_k = 50\n",
254
+ "coarse_top_p = 0.95\n",
255
+ "\n",
256
+ "fine_temp = 0.5\n",
257
+ "\n",
258
+ "use_semantic_history_prompt = True\n",
259
+ "use_coarse_history_prompt = True\n",
260
+ "use_fine_history_prompt = True\n",
261
+ "\n",
262
+ "use_last_generation_as_history = True\n",
263
+ "\n",
264
+ "if use_rvc:\n",
265
+ " index_rate = 0.75\n",
266
+ " f0up_key = -10\n",
267
+ " filter_radius = 3\n",
268
+ " rms_mix_rate = 0.25\n",
269
+ " protect = 0.33\n",
270
+ " resample_sr = SAMPLE_RATE\n",
271
+ " f0method = \"harvest\" #harvest or pm\n",
272
+ "\n",
273
+ "texts = split_and_recombine_text(text)\n",
274
+ "\n",
275
+ "all_parts = []\n",
276
+ "for i, text in tqdm(enumerate(texts), total=len(texts)):\n",
277
+ " full_generation, audio_array = generate_with_settings(\n",
278
+ " text,\n",
279
+ " semantic_temp=semantic_temp,\n",
280
+ " semantic_top_k=semantic_top_k,\n",
281
+ " semantic_top_p=semantic_top_p,\n",
282
+ " coarse_temp=coarse_temp,\n",
283
+ " coarse_top_k=coarse_top_k,\n",
284
+ " coarse_top_p=coarse_top_p,\n",
285
+ " fine_temp=fine_temp,\n",
286
+ " voice_name=voice_name,\n",
287
+ " use_semantic_history_prompt=use_semantic_history_prompt,\n",
288
+ " use_coarse_history_prompt=use_coarse_history_prompt,\n",
289
+ " use_fine_history_prompt=use_fine_history_prompt,\n",
290
+ " output_full=True\n",
291
+ " )\n",
292
+ " if use_last_generation_as_history:\n",
293
+ " # save to npz\n",
294
+ " os.makedirs('_temp', exist_ok=True)\n",
295
+ " np.savez_compressed(\n",
296
+ " '_temp/history.npz',\n",
297
+ " semantic_prompt=full_generation['semantic_prompt'],\n",
298
+ " coarse_prompt=full_generation['coarse_prompt'],\n",
299
+ " fine_prompt=full_generation['fine_prompt'],\n",
300
+ " )\n",
301
+ " voice_name = '_temp/history.npz'\n",
302
+ " write_wav(out_filepath.replace('.wav', f'_{i}') + '.wav', SAMPLE_RATE, audio_array)\n",
303
+ "\n",
304
+ " if use_rvc:\n",
305
+ " try:\n",
306
+ " audio_array = vc_single(0,out_filepath.replace('.wav', f'_{i}') + '.wav',f0up_key,None,f0method,index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
307
+ " except:\n",
308
+ " audio_array = vc_single(0,out_filepath.replace('.wav', f'_{i}') + '.wav',f0up_key,None,'pm',index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
309
+ " write_wav(out_filepath.replace('.wav', f'_{i}') + '.wav', SAMPLE_RATE, audio_array)\n",
310
+ " all_parts.append(audio_array)\n",
311
+ "\n",
312
+ "audio_array = np.concatenate(all_parts, axis=-1)\n",
313
+ "\n",
314
+ "# save audio\n",
315
+ "write_wav(out_filepath, SAMPLE_RATE, audio_array)\n",
316
+ "\n",
317
+ "# play audio\n",
318
+ "Audio(audio_array, rate=SAMPLE_RATE)"
319
+ ]
320
+ }
321
+ ],
322
+ "metadata": {
323
+ "kernelspec": {
324
+ "display_name": "Python 3",
325
+ "language": "python",
326
+ "name": "python3"
327
+ },
328
+ "language_info": {
329
+ "codemirror_mode": {
330
+ "name": "ipython",
331
+ "version": 3
332
+ },
333
+ "file_extension": ".py",
334
+ "mimetype": "text/x-python",
335
+ "name": "python",
336
+ "nbconvert_exporter": "python",
337
+ "pygments_lexer": "ipython3",
338
+ "version": "3.10.8"
339
+ },
340
+ "orig_nbformat": 4
341
+ },
342
+ "nbformat": 4,
343
+ "nbformat_minor": 2
344
+ }
model-card.md ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Model Card: Bark
2
+
3
+ This is the official codebase for running the text to audio model, from Suno.ai.
4
+
5
+ The following is additional information about the models released here.
6
+
7
+ ## Model Details
8
+
9
+ Bark is a series of three transformer models that turn text into audio.
10
+ ### Text to semantic tokens
11
+ - Input: text, tokenized with [BERT tokenizer from Hugging Face](https://huggingface.co/docs/transformers/model_doc/bert#transformers.BertTokenizer)
12
+ - Output: semantic tokens that encode the audio to be generated
13
+
14
+ ### Semantic to coarse tokens
15
+ - Input: semantic tokens
16
+ - Output: tokens from the first two codebooks of the [EnCodec Codec](https://github.com/facebookresearch/encodec) from facebook
17
+
18
+ ### Coarse to fine tokens
19
+ - Input: the first two codebooks from EnCodec
20
+ - Output: 8 codebooks from EnCodec
21
+
22
+ ### Architecture
23
+ | Model | Parameters | Attention | Output Vocab size |
24
+ |:-------------------------:|:----------:|------------|:-----------------:|
25
+ | Text to semantic tokens | 80 M | Causal | 10,000 |
26
+ | Semantic to coarse tokens | 80 M | Causal | 2x 1,024 |
27
+ | Coarse to fine tokens | 80 M | Non-causal | 6x 1,024 |
28
+
29
+
30
+ ### Release date
31
+ April 2023
32
+
33
+ ## Broader Implications
34
+ We anticipate that this model's text to audio capabilities can be used to improve accessbility tools in a variety of languages.
35
+ Straightforward improvements will allow models to run faster than realtime, rendering them useful for applications such as virtual assistants.
36
+
37
+ While we hope that this release will enable users to express their creativity and build applications that are a force
38
+ for good, we acknowledge that any text to audio model has the potential for dual use. While it is not straightforward
39
+ to voice clone known people with Bark, they can still be used for nefarious purposes. To further reduce the chances of unintended use of Bark,
40
+ we also release a simple classifier to detect Bark-generated audio with high accuracy (see notebooks section of the main repository).
pyproject.toml ADDED
@@ -0,0 +1,58 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [project]
6
+ name = "suno-bark"
7
+ version = "0.0.1a"
8
+ description = "Bark text to audio model"
9
+ readme = "README.md"
10
+ requires-python = ">=3.8"
11
+ authors = [
12
+ {name = "Suno Inc", email = "[email protected]"},
13
+ ]
14
+ # Apache 2.0
15
+ license = {file = "LICENSE"}
16
+
17
+ dependencies = [
18
+ "boto3",
19
+ "encodec",
20
+ "funcy",
21
+ "numpy",
22
+ "scipy",
23
+ "tokenizers",
24
+ "torch",
25
+ "tqdm",
26
+ "transformers",
27
+ ]
28
+
29
+ [project.urls]
30
+ source = "https://github.com/suno-ai/bark"
31
+
32
+ [project.optional-dependencies]
33
+ dev = [
34
+ "bandit",
35
+ "black",
36
+ "codecov",
37
+ "flake8",
38
+ "huggingface-hub",
39
+ "hypothesis>=6.14,<7",
40
+ "isort>=5.0.0,<6",
41
+ "jupyter",
42
+ "mypy",
43
+ "nbconvert",
44
+ "nbformat",
45
+ "pydocstyle",
46
+ "pylint",
47
+ "pytest",
48
+ "pytest-cov",
49
+ ]
50
+
51
+ [tool.setuptools]
52
+ packages = ["bark"]
53
+
54
+ [tool.setuptools.package-data]
55
+ bark = ["assets/prompts/*.npz"]
56
+
57
+ [tool.black]
58
+ line-length = 100
rvc_infer.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,sys,pdb,torch
2
+ now_dir = os.getcwd()
3
+ sys.path.append(now_dir)
4
+ import argparse
5
+ import glob
6
+ import sys
7
+ import torch
8
+ from multiprocessing import cpu_count
9
+ import ffmpeg
10
+ import numpy as np
11
+
12
+
13
+ def load_audio(file, sr):
14
+ try:
15
+ # https://github.com/openai/whisper/blob/main/whisper/audio.py#L26
16
+ # This launches a subprocess to decode audio while down-mixing and resampling as necessary.
17
+ # Requires the ffmpeg CLI and `ffmpeg-python` package to be installed.
18
+ file = (
19
+ file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
20
+ ) # 防止小白拷路径头尾带了空格和"和回车
21
+ out, _ = (
22
+ ffmpeg.input(file, threads=0)
23
+ .output("-", format="f32le", acodec="pcm_f32le", ac=1, ar=sr)
24
+ .run(cmd=["ffmpeg", "-nostdin"], capture_stdout=True, capture_stderr=True)
25
+ )
26
+ except Exception as e:
27
+ raise RuntimeError(f"Failed to load audio: {e}")
28
+
29
+ return np.frombuffer(out, np.float32).flatten()
30
+
31
+
32
+ class Config:
33
+ def __init__(self,device,is_half):
34
+ self.device = device
35
+ self.is_half = is_half
36
+ self.n_cpu = 0
37
+ self.gpu_name = None
38
+ self.gpu_mem = None
39
+ self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
40
+
41
+ def device_config(self) -> tuple:
42
+ if torch.cuda.is_available():
43
+ i_device = int(self.device.split(":")[-1])
44
+ self.gpu_name = torch.cuda.get_device_name(i_device)
45
+ if (
46
+ ("16" in self.gpu_name and "V100" not in self.gpu_name.upper())
47
+ or "P40" in self.gpu_name.upper()
48
+ or "1060" in self.gpu_name
49
+ or "1070" in self.gpu_name
50
+ or "1080" in self.gpu_name
51
+ ):
52
+ print("16系/10系显卡和P40强制单精度")
53
+ self.is_half = False
54
+ for config_file in ["32k.json", "40k.json", "48k.json"]:
55
+ with open(f"configs/{config_file}", "r") as f:
56
+ strr = f.read().replace("true", "false")
57
+ with open(f"configs/{config_file}", "w") as f:
58
+ f.write(strr)
59
+ with open("trainset_preprocess_pipeline_print.py", "r") as f:
60
+ strr = f.read().replace("3.7", "3.0")
61
+ with open("trainset_preprocess_pipeline_print.py", "w") as f:
62
+ f.write(strr)
63
+ else:
64
+ self.gpu_name = None
65
+ self.gpu_mem = int(
66
+ torch.cuda.get_device_properties(i_device).total_memory
67
+ / 1024
68
+ / 1024
69
+ / 1024
70
+ + 0.4
71
+ )
72
+ if self.gpu_mem <= 4:
73
+ with open("trainset_preprocess_pipeline_print.py", "r") as f:
74
+ strr = f.read().replace("3.7", "3.0")
75
+ with open("trainset_preprocess_pipeline_print.py", "w") as f:
76
+ f.write(strr)
77
+ elif torch.backends.mps.is_available():
78
+ print("没有发现支持的N卡, 使用MPS进行推理")
79
+ self.device = "mps"
80
+ else:
81
+ print("没有发现支持的N卡, 使用CPU进行推理")
82
+ self.device = "cpu"
83
+ self.is_half = True
84
+
85
+ if self.n_cpu == 0:
86
+ self.n_cpu = cpu_count()
87
+
88
+ if self.is_half:
89
+ # 6G显存配置
90
+ x_pad = 3
91
+ x_query = 10
92
+ x_center = 60
93
+ x_max = 65
94
+ else:
95
+ # 5G显存配置
96
+ x_pad = 1
97
+ x_query = 6
98
+ x_center = 38
99
+ x_max = 41
100
+
101
+ if self.gpu_mem != None and self.gpu_mem <= 4:
102
+ x_pad = 1
103
+ x_query = 5
104
+ x_center = 30
105
+ x_max = 32
106
+
107
+ return x_pad, x_query, x_center, x_max
108
+
109
+
110
+ now_dir=os.getcwd()
111
+ sys.path.append(now_dir)
112
+ sys.path.append(os.path.join(now_dir,"Retrieval-based-Voice-Conversion-WebUI"))
113
+ from vc_infer_pipeline import VC
114
+ from lib.infer_pack.models import SynthesizerTrnMs256NSFsid, SynthesizerTrnMs256NSFsid_nono, SynthesizerTrnMs768NSFsid, SynthesizerTrnMs768NSFsid_nono
115
+ from fairseq import checkpoint_utils
116
+ from scipy.io import wavfile
117
+
118
+ hubert_model=None
119
+ def load_hubert():
120
+ global hubert_model
121
+ models, saved_cfg, task = checkpoint_utils.load_model_ensemble_and_task(["hubert_base.pt"],suffix="",)
122
+ hubert_model = models[0]
123
+ hubert_model = hubert_model.to(device)
124
+ if(is_half):hubert_model = hubert_model.half()
125
+ else:hubert_model = hubert_model.float()
126
+ hubert_model.eval()
127
+
128
+ def vc_single(sid,input_audio,f0_up_key,f0_file,f0_method,file_index,index_rate,filter_radius=3,resample_sr=48000,rms_mix_rate=0.25, protect=0.33):
129
+ global tgt_sr,net_g,vc,hubert_model
130
+ if input_audio is None:return "You need to upload an audio", None
131
+ f0_up_key = int(f0_up_key)
132
+ audio=load_audio(input_audio,16000)
133
+ times = [0, 0, 0]
134
+ if(hubert_model==None):load_hubert()
135
+ if_f0 = cpt.get("f0", 1)
136
+ version = cpt.get("version")
137
+ audio_opt=vc.pipeline(hubert_model,net_g,sid,audio,input_audio,times,f0_up_key,f0_method,file_index,index_rate,if_f0,filter_radius=filter_radius,tgt_sr=tgt_sr,resample_sr=resample_sr,rms_mix_rate=rms_mix_rate,version=version,protect=protect,f0_file=f0_file)
138
+ # print(times)
139
+ return audio_opt
140
+
141
+
142
+ def get_vc(model_path, device_, is_half_):
143
+ global n_spk,tgt_sr,net_g,vc,cpt,device,is_half
144
+ device = device_
145
+ is_half = is_half_
146
+ config = Config(device, is_half)
147
+ print("loading pth %s"%model_path)
148
+ cpt = torch.load(model_path, map_location="cpu")
149
+ tgt_sr = cpt["config"][-1]
150
+ cpt["config"][-3]=cpt["weight"]["emb_g.weight"].shape[0]#n_spk
151
+ if_f0=cpt.get("f0",1)
152
+ version=cpt.get("version", "v2")
153
+ if(if_f0==1):
154
+ if version == "v1":
155
+ net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=is_half)
156
+ else:
157
+ net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=is_half)
158
+ else:
159
+ if version == "v1":
160
+ net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
161
+ else:
162
+ net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
163
+ del net_g.enc_q
164
+ print(net_g.load_state_dict(cpt["weight"], strict=False)) # 不加这一行清不干净,真奇葩
165
+ net_g.eval().to(device)
166
+ if (is_half):net_g = net_g.half()
167
+ else:net_g = net_g.float()
168
+ vc = VC(tgt_sr, config)
169
+ n_spk=cpt["config"][-3]
rvc_test.ipynb ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from rvc_infer import get_vc, vc_single"
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": null,
15
+ "metadata": {},
16
+ "outputs": [],
17
+ "source": [
18
+ "model_path = \"Retrieval-based-Voice-Conversion-WebUI/weights/mi-test.pth\"\n",
19
+ "device=\"cuda:0\"\n",
20
+ "is_half=True"
21
+ ]
22
+ },
23
+ {
24
+ "cell_type": "code",
25
+ "execution_count": null,
26
+ "metadata": {},
27
+ "outputs": [],
28
+ "source": [
29
+ "get_vc(model_path, device, is_half)"
30
+ ]
31
+ },
32
+ {
33
+ "cell_type": "code",
34
+ "execution_count": null,
35
+ "metadata": {},
36
+ "outputs": [],
37
+ "source": [
38
+ "index_rate = 0.75\n",
39
+ "f0up_key = -6\n",
40
+ "filter_radius = 3\n",
41
+ "rms_mix_rate = 0.25\n",
42
+ "protect = 0.33\n",
43
+ "resample_sr = 48000\n",
44
+ "f0method = \"harvest\" #harvest or pm\n",
45
+ "input_path = \"output/audio.wav\"\n",
46
+ "index_path = \"Retrieval-based-Voice-Conversion-WebUI/logs/mi-test/added_IVF256_Flat_nprobe_1_mi-test_v2.index\"\n",
47
+ "\n",
48
+ "wav_opt = vc_single(0,input_path,f0up_key,None,f0method,index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)"
49
+ ]
50
+ },
51
+ {
52
+ "cell_type": "code",
53
+ "execution_count": null,
54
+ "metadata": {},
55
+ "outputs": [],
56
+ "source": [
57
+ "from IPython.display import Audio\n",
58
+ "# play audio\n",
59
+ "Audio(wav_opt, rate=48000)"
60
+ ]
61
+ }
62
+ ],
63
+ "metadata": {
64
+ "kernelspec": {
65
+ "display_name": "Python 3",
66
+ "language": "python",
67
+ "name": "python3"
68
+ },
69
+ "language_info": {
70
+ "codemirror_mode": {
71
+ "name": "ipython",
72
+ "version": 3
73
+ },
74
+ "file_extension": ".py",
75
+ "mimetype": "text/x-python",
76
+ "name": "python",
77
+ "nbconvert_exporter": "python",
78
+ "pygments_lexer": "ipython3",
79
+ "version": "3.10.8"
80
+ },
81
+ "orig_nbformat": 4
82
+ },
83
+ "nbformat": 4,
84
+ "nbformat_minor": 2
85
+ }
setup.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from setuptools import setup
2
+
3
+ setup()
test_models.ipynb ADDED
@@ -0,0 +1,454 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "from IPython.display import Audio\n",
10
+ "from scipy.io.wavfile import write as write_wav\n",
11
+ "\n",
12
+ "from bark.api import generate_audio\n",
13
+ "from bark.generation import SAMPLE_RATE, preload_models, codec_decode, generate_coarse, generate_fine, generate_text_semantic"
14
+ ]
15
+ },
16
+ {
17
+ "cell_type": "code",
18
+ "execution_count": null,
19
+ "metadata": {},
20
+ "outputs": [],
21
+ "source": [
22
+ "semantic_path = \"semantic_output/pytorch_model.bin\" # set to None if you don't want to use finetuned semantic\n",
23
+ "coarse_path = \"coarse_output/pytorch_model.bin\" # set to None if you don't want to use finetuned coarse\n",
24
+ "fine_path = \"fine_output/pytorch_model.bin\" # set to None if you don't want to use finetuned fine\n",
25
+ "use_rvc = True # Set to False to use bark without RVC\n",
26
+ "rvc_name = 'mi-test'\n",
27
+ "rvc_path = f\"Retrieval-based-Voice-Conversion-WebUI/weights/{rvc_name}.pth\"\n",
28
+ "index_path = f\"Retrieval-based-Voice-Conversion-WebUI/logs/{rvc_name}/added_IVF256_Flat_nprobe_1_{rvc_name}_v2.index\"\n",
29
+ "device=\"cuda:0\"\n",
30
+ "is_half=True"
31
+ ]
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "execution_count": null,
36
+ "metadata": {},
37
+ "outputs": [],
38
+ "source": [
39
+ "preload_models(\n",
40
+ " text_use_gpu=True,\n",
41
+ " text_use_small=False,\n",
42
+ " text_model_path=semantic_path,\n",
43
+ " coarse_use_gpu=True,\n",
44
+ " coarse_use_small=False,\n",
45
+ " coarse_model_path=coarse_path,\n",
46
+ " fine_use_gpu=True,\n",
47
+ " fine_use_small=False,\n",
48
+ " fine_model_path=fine_path,\n",
49
+ " codec_use_gpu=True,\n",
50
+ " force_reload=False,\n",
51
+ " path=\"models\"\n",
52
+ ")\n",
53
+ "\n",
54
+ "if use_rvc:\n",
55
+ " from rvc_infer import get_vc, vc_single\n",
56
+ " get_vc(rvc_path, device, is_half)"
57
+ ]
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "execution_count": null,
62
+ "metadata": {},
63
+ "outputs": [],
64
+ "source": [
65
+ "# simple generation\n",
66
+ "text_prompt = \"I am Joe Biden... and this is the finetuned semantic, coarse and fine model! A lot better than the original!\"\n",
67
+ "filepath = \"output/audio.wav\" # change this to your desired output path\n",
68
+ "audio_array = generate_audio(text_prompt, history_prompt=None, text_temp=0.7, waveform_temp=0.7)\n",
69
+ "write_wav(filepath, SAMPLE_RATE, audio_array)\n",
70
+ "\n",
71
+ "if use_rvc:\n",
72
+ " index_rate = 0.75\n",
73
+ " f0up_key = -6\n",
74
+ " filter_radius = 3\n",
75
+ " rms_mix_rate = 0.25\n",
76
+ " protect = 0.33\n",
77
+ " resample_sr = SAMPLE_RATE\n",
78
+ " f0method = \"harvest\" #harvest or pm\n",
79
+ " try:\n",
80
+ " audio_array = vc_single(0,filepath,f0up_key,None,f0method,index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
81
+ " except:\n",
82
+ " audio_array = vc_single(0,filepath,f0up_key,None,'pm',index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
83
+ " write_wav(filepath, SAMPLE_RATE, audio_array)\n",
84
+ "\n",
85
+ "Audio(audio_array, rate=SAMPLE_RATE)"
86
+ ]
87
+ },
88
+ {
89
+ "cell_type": "code",
90
+ "execution_count": null,
91
+ "metadata": {},
92
+ "outputs": [],
93
+ "source": [
94
+ "def generate_with_settings(text_prompt, semantic_temp=0.7, semantic_top_k=50, semantic_top_p=0.95, coarse_temp=0.7, coarse_top_k=50, coarse_top_p=0.95, fine_temp=0.5, voice_name=None, use_semantic_history_prompt=True, use_coarse_history_prompt=True, use_fine_history_prompt=True, output_full=False):\n",
95
+ " # generation with more control\n",
96
+ " x_semantic = generate_text_semantic(\n",
97
+ " text_prompt,\n",
98
+ " history_prompt=voice_name if use_semantic_history_prompt else None,\n",
99
+ " temp=semantic_temp,\n",
100
+ " top_k=semantic_top_k,\n",
101
+ " top_p=semantic_top_p,\n",
102
+ " )\n",
103
+ "\n",
104
+ " x_coarse_gen = generate_coarse(\n",
105
+ " x_semantic,\n",
106
+ " history_prompt=voice_name if use_coarse_history_prompt else None,\n",
107
+ " temp=coarse_temp,\n",
108
+ " top_k=coarse_top_k,\n",
109
+ " top_p=coarse_top_p,\n",
110
+ " )\n",
111
+ " x_fine_gen = generate_fine(\n",
112
+ " x_coarse_gen,\n",
113
+ " history_prompt=voice_name if use_fine_history_prompt else None,\n",
114
+ " temp=fine_temp,\n",
115
+ " )\n",
116
+ "\n",
117
+ " if output_full:\n",
118
+ " full_generation = {\n",
119
+ " 'semantic_prompt': x_semantic,\n",
120
+ " 'coarse_prompt': x_coarse_gen,\n",
121
+ " 'fine_prompt': x_fine_gen,\n",
122
+ " }\n",
123
+ " return full_generation, codec_decode(x_fine_gen)\n",
124
+ " return codec_decode(x_fine_gen)"
125
+ ]
126
+ },
127
+ {
128
+ "cell_type": "code",
129
+ "execution_count": null,
130
+ "metadata": {},
131
+ "outputs": [],
132
+ "source": [
133
+ "text_prompt = \"I am Joe Biden... and this is the finetuned semantic, coarse and fine model! A lot better than the original!\"\n",
134
+ "filepath = \"output/audio.wav\" # change this to your desired output path\n",
135
+ "\n",
136
+ "audio_array = generate_with_settings(\n",
137
+ " text_prompt,\n",
138
+ " semantic_temp=0.7,\n",
139
+ " semantic_top_k=50,\n",
140
+ " semantic_top_p=0.99,\n",
141
+ " coarse_temp=0.7,\n",
142
+ " coarse_top_k=50,\n",
143
+ " coarse_top_p=0.95,\n",
144
+ " fine_temp=0.5,\n",
145
+ " voice_name=\"datasets/joe_biden_state_of_union/tokens/257.npz\",\n",
146
+ " use_semantic_history_prompt=False,\n",
147
+ " use_coarse_history_prompt=True,\n",
148
+ " use_fine_history_prompt=True,\n",
149
+ " output_full=False\n",
150
+ ")\n",
151
+ "\n",
152
+ "write_wav(filepath, SAMPLE_RATE, audio_array)\n",
153
+ "\n",
154
+ "if use_rvc:\n",
155
+ " index_rate = 0.75\n",
156
+ " f0up_key = -6\n",
157
+ " filter_radius = 3\n",
158
+ " rms_mix_rate = 0.25\n",
159
+ " protect = 0.33\n",
160
+ " resample_sr = SAMPLE_RATE\n",
161
+ " f0method = \"harvest\" #harvest or pm\n",
162
+ " try:\n",
163
+ " audio_array = vc_single(0,filepath,f0up_key,None,f0method,index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
164
+ " except:\n",
165
+ " audio_array = vc_single(0,filepath,f0up_key,None,'pm',index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
166
+ " write_wav(filepath, SAMPLE_RATE, audio_array)\n",
167
+ "\n",
168
+ "Audio(audio_array, rate=SAMPLE_RATE)"
169
+ ]
170
+ },
171
+ {
172
+ "cell_type": "code",
173
+ "execution_count": null,
174
+ "metadata": {},
175
+ "outputs": [],
176
+ "source": []
177
+ },
178
+ {
179
+ "cell_type": "code",
180
+ "execution_count": null,
181
+ "metadata": {},
182
+ "outputs": [],
183
+ "source": [
184
+ "import re\n",
185
+ "def split_and_recombine_text(text, desired_length=100, max_length=150):\n",
186
+ " # from https://github.com/neonbjb/tortoise-tts\n",
187
+ " \"\"\"Split text it into chunks of a desired length trying to keep sentences intact.\"\"\"\n",
188
+ " # normalize text, remove redundant whitespace and convert non-ascii quotes to ascii\n",
189
+ " text = re.sub(r\"\\n\\n+\", \"\\n\", text)\n",
190
+ " text = re.sub(r\"\\s+\", \" \", text)\n",
191
+ " text = re.sub(r\"[“”]\", '\"', text)\n",
192
+ "\n",
193
+ " rv = []\n",
194
+ " in_quote = False\n",
195
+ " current = \"\"\n",
196
+ " split_pos = []\n",
197
+ " pos = -1\n",
198
+ " end_pos = len(text) - 1\n",
199
+ "\n",
200
+ " def seek(delta):\n",
201
+ " nonlocal pos, in_quote, current\n",
202
+ " is_neg = delta < 0\n",
203
+ " for _ in range(abs(delta)):\n",
204
+ " if is_neg:\n",
205
+ " pos -= 1\n",
206
+ " current = current[:-1]\n",
207
+ " else:\n",
208
+ " pos += 1\n",
209
+ " current += text[pos]\n",
210
+ " if text[pos] == '\"':\n",
211
+ " in_quote = not in_quote\n",
212
+ " return text[pos]\n",
213
+ "\n",
214
+ " def peek(delta):\n",
215
+ " p = pos + delta\n",
216
+ " return text[p] if p < end_pos and p >= 0 else \"\"\n",
217
+ "\n",
218
+ " def commit():\n",
219
+ " nonlocal rv, current, split_pos\n",
220
+ " rv.append(current)\n",
221
+ " current = \"\"\n",
222
+ " split_pos = []\n",
223
+ "\n",
224
+ " while pos < end_pos:\n",
225
+ " c = seek(1)\n",
226
+ " # do we need to force a split?\n",
227
+ " if len(current) >= max_length:\n",
228
+ " if len(split_pos) > 0 and len(current) > (desired_length / 2):\n",
229
+ " # we have at least one sentence and we are over half the desired length, seek back to the last split\n",
230
+ " d = pos - split_pos[-1]\n",
231
+ " seek(-d)\n",
232
+ " else:\n",
233
+ " # no full sentences, seek back until we are not in the middle of a word and split there\n",
234
+ " while c not in \"!?.\\n \" and pos > 0 and len(current) > desired_length:\n",
235
+ " c = seek(-1)\n",
236
+ " commit()\n",
237
+ " # check for sentence boundaries\n",
238
+ " elif not in_quote and (c in \"!?\\n\" or (c == \".\" and peek(1) in \"\\n \")):\n",
239
+ " # seek forward if we have consecutive boundary markers but still within the max length\n",
240
+ " while (\n",
241
+ " pos < len(text) - 1 and len(current) < max_length and peek(1) in \"!?.\"\n",
242
+ " ):\n",
243
+ " c = seek(1)\n",
244
+ " split_pos.append(pos)\n",
245
+ " if len(current) >= desired_length:\n",
246
+ " commit()\n",
247
+ " # treat end of quote as a boundary if its followed by a space or newline\n",
248
+ " elif in_quote and peek(1) == '\"' and peek(2) in \"\\n \":\n",
249
+ " seek(2)\n",
250
+ " split_pos.append(pos)\n",
251
+ " rv.append(current)\n",
252
+ "\n",
253
+ " # clean up, remove lines with only whitespace or punctuation\n",
254
+ " rv = [s.strip() for s in rv]\n",
255
+ " rv = [s for s in rv if len(s) > 0 and not re.match(r\"^[\\s\\.,;:!?]*$\", s)]\n",
256
+ "\n",
257
+ " return rv\n",
258
+ "\n",
259
+ "def generate_with_settings(text_prompt, semantic_temp=0.7, semantic_top_k=50, semantic_top_p=0.95, coarse_temp=0.7, coarse_top_k=50, coarse_top_p=0.95, fine_temp=0.5, voice_name=None, use_semantic_history_prompt=True, use_coarse_history_prompt=True, use_fine_history_prompt=True, output_full=False):\n",
260
+ " # generation with more control\n",
261
+ " x_semantic = generate_text_semantic(\n",
262
+ " text_prompt,\n",
263
+ " history_prompt=voice_name if use_semantic_history_prompt else None,\n",
264
+ " temp=semantic_temp,\n",
265
+ " top_k=semantic_top_k,\n",
266
+ " top_p=semantic_top_p,\n",
267
+ " )\n",
268
+ "\n",
269
+ " x_coarse_gen = generate_coarse(\n",
270
+ " x_semantic,\n",
271
+ " history_prompt=voice_name if use_coarse_history_prompt else None,\n",
272
+ " temp=coarse_temp,\n",
273
+ " top_k=coarse_top_k,\n",
274
+ " top_p=coarse_top_p,\n",
275
+ " )\n",
276
+ " x_fine_gen = generate_fine(\n",
277
+ " x_coarse_gen,\n",
278
+ " history_prompt=voice_name if use_fine_history_prompt else None,\n",
279
+ " temp=fine_temp,\n",
280
+ " )\n",
281
+ "\n",
282
+ " if output_full:\n",
283
+ " full_generation = {\n",
284
+ " 'semantic_prompt': x_semantic,\n",
285
+ " 'coarse_prompt': x_coarse_gen,\n",
286
+ " 'fine_prompt': x_fine_gen,\n",
287
+ " }\n",
288
+ " return full_generation, codec_decode(x_fine_gen)\n",
289
+ " return codec_decode(x_fine_gen)"
290
+ ]
291
+ },
292
+ {
293
+ "cell_type": "code",
294
+ "execution_count": null,
295
+ "metadata": {},
296
+ "outputs": [],
297
+ "source": [
298
+ "text = \"\"\"The Uncharted Land of Discovery: A Journey Through Time and Space\n",
299
+ "[clears throat]\n",
300
+ "Chapter 1: The Dawn of Curiosity\n",
301
+ "[takes breath]\n",
302
+ "Since the dawn of humankind, our species has been driven by a powerful force: curiosity. It is an innate, unquenchable desire to explore, understand, and unravel the mysteries of the world around us. This primal urge has led us on countless adventures, pushing us to the farthest reaches of our planet and beyond.\n",
303
+ "\n",
304
+ "Early humans, huddled around a flickering fire, gazed up at the night sky and wondered what those twinkling lights were. They had no idea that their curiosity would eventually propel us into the vast, uncharted realm of space. As time progressed, our ancestors began to explore their surroundings, venturing beyond their caves and settlements, driven by the need to discover what lay beyond the horizon.\n",
305
+ "\n",
306
+ "hapter 2: The Age of Exploration\n",
307
+ "\n",
308
+ "The Age of Exploration marked a turning point in human history, as brave souls took to the seas in search of new lands, wealth, and knowledge. Pioneers like Christopher Columbus, Vasco da Gama, and Ferdinand Magellan set sail on perilous voyages, pushing the boundaries of what was known and understood.\n",
309
+ "[clears throat]\n",
310
+ "These intrepid explorers discovered new continents, mapped out previously unknown territories, and encountered diverse cultures. They also established trade routes, allowing for the exchange of goods, ideas, and innovations between distant societies. The Age of Exploration was not without its dark moments, however, as conquest, colonization, and exploitation often went hand in hand with discovery.\n",
311
+ "[clears throat]\n",
312
+ "Chapter 3: The Scientific Revolution\n",
313
+ "[laughs]\n",
314
+ "The Scientific Revolution was a period of profound change, as humanity began to question long-held beliefs and seek empirical evidence. Pioneers like Galileo Galilei, Isaac Newton, and Johannes Kepler sought to understand the natural world through observation, experimentation, and reason.\n",
315
+ "[sighs]\n",
316
+ "Their discoveries laid the foundation for modern science, transforming the way we view the universe and our place within it. New technologies, such as the telescope and the microscope, allowed us to peer deeper into the cosmos and the microscopic world, further expanding our understanding of reality.\n",
317
+ "[gasps]\n",
318
+ "Chapter 4: The Information Age\n",
319
+ "\n",
320
+ "The Information Age, sometimes referred to as the Digital Age, has revolutionized the way we communicate, learn, and access knowledge. With the advent of the internet and personal computers, information that was once reserved for the privileged few is now available to the masses.\n",
321
+ "...\n",
322
+ "This democratization of knowledge has led to an explosion of innovation, as ideas and information are shared across borders and cultures at lightning speed. The Information Age has also brought new challenges, as the rapid pace of technological advancements threatens to outpace our ability to adapt and raises questions about the ethical implications of our increasingly interconnected world.\n",
323
+ "[laughter]\n",
324
+ "Chapter 5: The Final Frontier\n",
325
+ "[clears throat]\n",
326
+ "As our knowledge of the universe expands, so too does our desire to explore the cosmos. Space exploration has come a long way since the first successful satellite, Sputnik, was launched in 1957. We have landed humans on the moon, sent probes to the far reaches of our solar system, and even glimpsed distant galaxies through powerful telescopes.\n",
327
+ "\n",
328
+ "The future of space exploration is filled with possibilities, from establishing colonies on Mars to the search for extraterrestrial life. As we venture further into the unknown, we continue to be driven by the same curiosity that has propelled us throughout history, always seeking to uncover the secrets of the universe and our place within it.\n",
329
+ "...\n",
330
+ "In conclusion, the human journey is one of discovery, driven by our innate curiosity and desire to understand the world around us. From the dawn of our species to the present day, we have continued to explore, learn, and adapt, pushing the boundaries of what is known and possible. As we continue to unravel the mysteries of the cosmos, our spirit.\"\"\""
331
+ ]
332
+ },
333
+ {
334
+ "cell_type": "code",
335
+ "execution_count": null,
336
+ "metadata": {},
337
+ "outputs": [],
338
+ "source": [
339
+ "# Chunk the text into smaller pieces then combine the generated audio\n",
340
+ "from time import time\n",
341
+ "from tqdm.auto import tqdm\n",
342
+ "from IPython.display import Audio\n",
343
+ "from scipy.io.wavfile import write as write_wav\n",
344
+ "import os\n",
345
+ "import numpy as np\n",
346
+ "\n",
347
+ "# generation settings\n",
348
+ "voice_name = \"datasets/joe_biden_state_of_union/tokens/257.npz\"\n",
349
+ "out_filepath = 'audio/audio.wav'\n",
350
+ "\n",
351
+ "semantic_temp = 0.7\n",
352
+ "semantic_top_k = 100\n",
353
+ "semantic_top_p = 0.99\n",
354
+ "\n",
355
+ "coarse_temp = 0.7\n",
356
+ "coarse_top_k = 100\n",
357
+ "coarse_top_p = 0.95\n",
358
+ "\n",
359
+ "fine_temp = 0.7\n",
360
+ "\n",
361
+ "use_semantic_history_prompt = True\n",
362
+ "use_coarse_history_prompt = True\n",
363
+ "use_fine_history_prompt = True\n",
364
+ "\n",
365
+ "use_last_generation_as_history = False\n",
366
+ "\n",
367
+ "if use_rvc:\n",
368
+ " index_rate = 0.75\n",
369
+ " f0up_key = -6\n",
370
+ " filter_radius = 3\n",
371
+ " rms_mix_rate = 0.25\n",
372
+ " protect = 0.33\n",
373
+ " resample_sr = SAMPLE_RATE\n",
374
+ " f0method = \"harvest\" #harvest or pm\n",
375
+ "\n",
376
+ "texts = split_and_recombine_text(text)\n",
377
+ "\n",
378
+ "all_parts = []\n",
379
+ "for i, text in tqdm(enumerate(texts), total=len(texts)):\n",
380
+ " full_generation, audio_array = generate_with_settings(\n",
381
+ " text,\n",
382
+ " semantic_temp=semantic_temp,\n",
383
+ " semantic_top_k=semantic_top_k,\n",
384
+ " semantic_top_p=semantic_top_p,\n",
385
+ " coarse_temp=coarse_temp,\n",
386
+ " coarse_top_k=coarse_top_k,\n",
387
+ " coarse_top_p=coarse_top_p,\n",
388
+ " fine_temp=fine_temp,\n",
389
+ " voice_name=voice_name,\n",
390
+ " use_semantic_history_prompt=use_semantic_history_prompt,\n",
391
+ " use_coarse_history_prompt=use_coarse_history_prompt,\n",
392
+ " use_fine_history_prompt=use_fine_history_prompt,\n",
393
+ " output_full=True\n",
394
+ " )\n",
395
+ " if use_last_generation_as_history:\n",
396
+ " # save to npz\n",
397
+ " os.makedirs('_temp', exist_ok=True)\n",
398
+ " np.savez_compressed(\n",
399
+ " '_temp/history.npz',\n",
400
+ " semantic_prompt=full_generation['semantic_prompt'],\n",
401
+ " coarse_prompt=full_generation['coarse_prompt'],\n",
402
+ " fine_prompt=full_generation['fine_prompt'],\n",
403
+ " )\n",
404
+ " voice_name = '_temp/history.npz'\n",
405
+ " write_wav(out_filepath.replace('.wav', f'_{i}') + '.wav', SAMPLE_RATE, audio_array)\n",
406
+ "\n",
407
+ " if use_rvc:\n",
408
+ " try:\n",
409
+ " audio_array = vc_single(0,out_filepath.replace('.wav', f'_{i}') + '.wav',f0up_key,None,f0method,index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
410
+ " except:\n",
411
+ " audio_array = vc_single(0,out_filepath.replace('.wav', f'_{i}') + '.wav',f0up_key,None,'pm',index_path,index_rate, filter_radius=filter_radius, resample_sr=resample_sr, rms_mix_rate=rms_mix_rate, protect=protect)\n",
412
+ " write_wav(out_filepath.replace('.wav', f'_{i}') + '.wav', SAMPLE_RATE, audio_array)\n",
413
+ " all_parts.append(audio_array)\n",
414
+ "\n",
415
+ "audio_array = np.concatenate(all_parts, axis=-1)\n",
416
+ "\n",
417
+ "# save audio\n",
418
+ "write_wav(out_filepath, SAMPLE_RATE, audio_array)\n",
419
+ "\n",
420
+ "# play audio\n",
421
+ "Audio(audio_array, rate=SAMPLE_RATE)"
422
+ ]
423
+ },
424
+ {
425
+ "cell_type": "code",
426
+ "execution_count": null,
427
+ "metadata": {},
428
+ "outputs": [],
429
+ "source": []
430
+ }
431
+ ],
432
+ "metadata": {
433
+ "kernelspec": {
434
+ "display_name": "Python 3",
435
+ "language": "python",
436
+ "name": "python3"
437
+ },
438
+ "language_info": {
439
+ "codemirror_mode": {
440
+ "name": "ipython",
441
+ "version": 3
442
+ },
443
+ "file_extension": ".py",
444
+ "mimetype": "text/x-python",
445
+ "name": "python",
446
+ "nbconvert_exporter": "python",
447
+ "pygments_lexer": "ipython3",
448
+ "version": "3.10.8"
449
+ },
450
+ "orig_nbformat": 4
451
+ },
452
+ "nbformat": 4,
453
+ "nbformat_minor": 2
454
+ }
train_coarse.ipynb ADDED
@@ -0,0 +1,936 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Imports"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "metadata": {},
15
+ "outputs": [],
16
+ "source": [
17
+ "import torch\n",
18
+ "import torch.nn as nn\n",
19
+ "import os\n",
20
+ "import re\n",
21
+ "import gc\n",
22
+ "import math\n",
23
+ "import json\n",
24
+ "import hashlib\n",
25
+ "import numpy as np\n",
26
+ "import logging\n",
27
+ "import torchaudio\n",
28
+ "from tqdm.auto import tqdm\n",
29
+ "import torch.nn.functional as F\n",
30
+ "from encodec.utils import convert_audio\n",
31
+ "from accelerate import Accelerator\n",
32
+ "from accelerate.utils import set_seed\n",
33
+ "from transformers import BertTokenizer\n",
34
+ "from huggingface_hub import hf_hub_download\n",
35
+ "from packaging import version\n",
36
+ "from diffusers.optimization import get_scheduler\n",
37
+ "\n",
38
+ "from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n",
39
+ "from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n",
40
+ "from bark.model import GPTConfig, GPT\n",
41
+ "from bark.model_fine import FineGPT, FineGPTConfig"
42
+ ]
43
+ },
44
+ {
45
+ "attachments": {},
46
+ "cell_type": "markdown",
47
+ "metadata": {},
48
+ "source": [
49
+ "# Training Args"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "train_batch_size = 8\n",
59
+ "eval_batch_size = 8\n",
60
+ "grad_accum = 2\n",
61
+ "ckpt_path = 'models/coarse_2.pt'\n",
62
+ "model_type = \"coarse\"\n",
63
+ "dataset_path = 'datasets/joe_biden_state_of_union/'\n",
64
+ "logging_dir = 'logs/'\n",
65
+ "log_with = 'wandb'\n",
66
+ "hubert_path = 'data/models/hubert/hubert.pt'\n",
67
+ "hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n",
68
+ "\n",
69
+ "output_dir = 'coarse_output/'\n",
70
+ "resume_from_checkpoint = None\n",
71
+ "\n",
72
+ "checkpointing_steps = 1000\n",
73
+ "\n",
74
+ "mixed_precision = 'bf16'\n",
75
+ "bits = 16 #4 4 and 8 bit are a work in progress\n",
76
+ "compute_dtype = torch.bfloat16\n",
77
+ "double_quant = True\n",
78
+ "quant_type = 'nf4'\n",
79
+ "\n",
80
+ "lora_dim = 64\n",
81
+ "lora_scaling = 1\n",
82
+ "lora_dropout = 0.1\n",
83
+ "lora_module_name = 'transformer.h'\n",
84
+ "optimize_lora_params_only = False\n",
85
+ "\n",
86
+ "learning_rate = 1e-4\n",
87
+ "scale_lr = False\n",
88
+ "use_8bit_adam = False\n",
89
+ "adam_beta1 = 0.9\n",
90
+ "adam_beta2 = 0.999\n",
91
+ "adam_epsilon = 1e-8\n",
92
+ "weight_decay = 0.01\n",
93
+ "\n",
94
+ "llm_int8_skip_modules = None\n",
95
+ "keep_in_fp32_modules = ['lm_head']\n",
96
+ "\n",
97
+ "lr_scheduler_type = 'linear'\n",
98
+ "lr_warmup_steps = 60\n",
99
+ "num_train_epochs = 5\n",
100
+ "max_train_steps = None\n",
101
+ "max_grad_norm = 1.0\n",
102
+ "\n",
103
+ "semantic_cross_entropy_loss_weight = 0.0\n",
104
+ "\n",
105
+ "seed = 741"
106
+ ]
107
+ },
108
+ {
109
+ "attachments": {},
110
+ "cell_type": "markdown",
111
+ "metadata": {},
112
+ "source": [
113
+ "# Define Functions"
114
+ ]
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "execution_count": null,
119
+ "metadata": {},
120
+ "outputs": [],
121
+ "source": [
122
+ "CONTEXT_WINDOW_SIZE = 1024\n",
123
+ "\n",
124
+ "MAX_SEMANTIC_LEN = 256\n",
125
+ "\n",
126
+ "SEMANTIC_RATE_HZ = 49.9\n",
127
+ "SEMANTIC_VOCAB_SIZE = 10_000\n",
128
+ "\n",
129
+ "TEXT_ENCODING_OFFSET = 10_048\n",
130
+ "SEMANTIC_PAD_TOKEN = 10_000\n",
131
+ "TEXT_PAD_TOKEN = 129_595\n",
132
+ "SEMANTIC_INFER_TOKEN = 129_599\n",
133
+ "\n",
134
+ "MAX_COARSE_LEN = 768\n",
135
+ "\n",
136
+ "SAMPLE_RATE = 24_000\n",
137
+ "CHANNELS = 1\n",
138
+ "\n",
139
+ "COARSE_SEMANTIC_PAD_TOKEN = 12_048\n",
140
+ "COARSE_INFER_TOKEN = 12_050\n",
141
+ "\n",
142
+ "CODEBOOK_SIZE = 1024\n",
143
+ "N_COARSE_CODEBOOKS = 2\n",
144
+ "N_FINE_CODEBOOKS = 8\n",
145
+ "COARSE_RATE_HZ = 75\n",
146
+ "\n",
147
+ "logger = logging.getLogger(__name__)\n",
148
+ "\n",
149
+ "\n",
150
+ "USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n",
151
+ "\n",
152
+ "default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n",
153
+ "CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n",
154
+ "\n",
155
+ "\n",
156
+ "def _clear_cuda_cache():\n",
157
+ " if torch.cuda.is_available():\n",
158
+ " torch.cuda.empty_cache()\n",
159
+ " torch.cuda.synchronize()\n",
160
+ "\n",
161
+ "\n",
162
+ "def _md5(fname):\n",
163
+ " hash_md5 = hashlib.md5()\n",
164
+ " with open(fname, \"rb\") as f:\n",
165
+ " for chunk in iter(lambda: f.read(4096), b\"\"):\n",
166
+ " hash_md5.update(chunk)\n",
167
+ " return hash_md5.hexdigest()\n",
168
+ "\n",
169
+ "\n",
170
+ "def _download(from_hf_path, file_name, to_local_path):\n",
171
+ " to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n",
172
+ " path = '/'.join(to_local_path.split(\"/\")[:-1])\n",
173
+ " os.makedirs(path, exist_ok=True)\n",
174
+ " hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n",
175
+ " os.replace(os.path.join(path, file_name), to_local_path)\n",
176
+ "\n",
177
+ "\n",
178
+ "def _tokenize(tokenizer, text):\n",
179
+ " return tokenizer.encode(text, add_special_tokens=False)\n",
180
+ "\n",
181
+ "\n",
182
+ "def _detokenize(tokenizer, enc_text):\n",
183
+ " return tokenizer.decode(enc_text)\n",
184
+ "\n",
185
+ "\n",
186
+ "def _normalize_whitespace(text):\n",
187
+ " return re.sub(r\"\\s+\", \" \", text).strip()\n",
188
+ "\n",
189
+ "\n",
190
+ "REMOTE_MODEL_PATHS = {\n",
191
+ " \"text_small\": {\n",
192
+ " \"repo_id\": \"suno/bark\",\n",
193
+ " \"file_name\": \"text.pt\",\n",
194
+ " \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n",
195
+ " },\n",
196
+ " \"coarse_small\": {\n",
197
+ " \"repo_id\": \"suno/bark\",\n",
198
+ " \"file_name\": \"coarse.pt\",\n",
199
+ " \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n",
200
+ " },\n",
201
+ " \"fine_small\": {\n",
202
+ " \"repo_id\": \"suno/bark\",\n",
203
+ " \"file_name\": \"fine.pt\",\n",
204
+ " \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n",
205
+ " },\n",
206
+ " \"text\": {\n",
207
+ " \"repo_id\": \"suno/bark\",\n",
208
+ " \"file_name\": \"text_2.pt\",\n",
209
+ " \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n",
210
+ " },\n",
211
+ " \"coarse\": {\n",
212
+ " \"repo_id\": \"suno/bark\",\n",
213
+ " \"file_name\": \"coarse_2.pt\",\n",
214
+ " \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n",
215
+ " },\n",
216
+ " \"fine\": {\n",
217
+ " \"repo_id\": \"suno/bark\",\n",
218
+ " \"file_name\": \"fine_2.pt\",\n",
219
+ " \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n",
220
+ " },\n",
221
+ "}\n",
222
+ "\n",
223
+ "\n",
224
+ "def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n",
225
+ " if model_type == \"text\":\n",
226
+ " ConfigClass = GPTConfig\n",
227
+ " ModelClass = GPT\n",
228
+ " elif model_type == \"coarse\":\n",
229
+ " ConfigClass = GPTConfig\n",
230
+ " ModelClass = GPT\n",
231
+ " elif model_type == \"fine\":\n",
232
+ " ConfigClass = FineGPTConfig\n",
233
+ " ModelClass = FineGPT\n",
234
+ " else:\n",
235
+ " raise NotImplementedError()\n",
236
+ " model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n",
237
+ " model_info = REMOTE_MODEL_PATHS[model_key]\n",
238
+ " if ckpt_path in [None, '']:\n",
239
+ " ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n",
240
+ " if not os.path.exists(ckpt_path):\n",
241
+ " logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n",
242
+ " _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n",
243
+ " checkpoint = torch.load(ckpt_path, map_location=device)\n",
244
+ " # this is a hack\n",
245
+ " model_args = checkpoint[\"model_args\"]\n",
246
+ " if \"input_vocab_size\" not in model_args:\n",
247
+ " model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n",
248
+ " model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n",
249
+ " del model_args[\"vocab_size\"]\n",
250
+ " gptconf = ConfigClass(**checkpoint[\"model_args\"])\n",
251
+ " model = ModelClass(gptconf)\n",
252
+ " state_dict = checkpoint[\"model\"]\n",
253
+ " # fixup checkpoint\n",
254
+ " unwanted_prefix = \"_orig_mod.\"\n",
255
+ " for k, v in list(state_dict.items()):\n",
256
+ " if k.startswith(unwanted_prefix):\n",
257
+ " state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)\n",
258
+ " extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n",
259
+ " extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n",
260
+ " missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n",
261
+ " missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n",
262
+ " if len(extra_keys) != 0:\n",
263
+ " raise ValueError(f\"extra keys found: {extra_keys}\")\n",
264
+ " if len(missing_keys) != 0:\n",
265
+ " raise ValueError(f\"missing keys: {missing_keys}\")\n",
266
+ " model.load_state_dict(state_dict, strict=False)\n",
267
+ " n_params = model.get_num_params()\n",
268
+ " val_loss = checkpoint[\"best_val_loss\"].item()\n",
269
+ " print(f\"Loaded {model_type} model with {n_params} params, val_loss={val_loss:.4f}.\")\n",
270
+ " del checkpoint, state_dict\n",
271
+ " _clear_cuda_cache()\n",
272
+ " if model_type == \"text\":\n",
273
+ " tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")\n",
274
+ " return model, tokenizer\n",
275
+ " return model\n",
276
+ "\n",
277
+ "\n",
278
+ "def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):\n",
279
+ " assert len(arr.shape) == 2\n",
280
+ " arr = arr.copy()\n",
281
+ " if offset_size is not None:\n",
282
+ " for n in range(1, arr.shape[0]):\n",
283
+ " arr[n, :] += offset_size * n\n",
284
+ " flat_arr = arr.ravel(\"F\")\n",
285
+ " return flat_arr\n",
286
+ "\n",
287
+ "\n",
288
+ "def load_filepaths_and_text(filename, split=\"|\"):\n",
289
+ " with open(filename, encoding='utf-8', errors='ignore') as f:\n",
290
+ " filepaths_and_text = [line.strip().split(split) for line in f]\n",
291
+ " base = os.path.dirname(filename)\n",
292
+ " for j in range(len(filepaths_and_text)):\n",
293
+ " filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])\n",
294
+ " return filepaths_and_text\n",
295
+ "\n",
296
+ "\n",
297
+ "class TtsDataset(torch.utils.data.Dataset):\n",
298
+ " def __init__(self, opt):\n",
299
+ " self.path = os.path.dirname(opt['path'])\n",
300
+ " self.mode = opt['mode']\n",
301
+ " self.audiopaths_and_text = load_filepaths_and_text(os.path.join(opt['path'] , opt['mode'] + '.txt'))\n",
302
+ "\n",
303
+ " def __getitem__(self, index):\n",
304
+ " audiopath_and_text = self.audiopaths_and_text[index]\n",
305
+ " audiopath = audiopath_and_text[0]\n",
306
+ "\n",
307
+ " tokens = np.load(audiopath.replace('.wav', '.npz').replace('wavs', 'tokens'))\n",
308
+ " semantic_tokens = tokens['semantic']\n",
309
+ " coarse_tokens = _flatten_codebooks(tokens['coarse'], offset_size=CODEBOOK_SIZE) + SEMANTIC_VOCAB_SIZE\n",
310
+ "\n",
311
+ " return torch.from_numpy(semantic_tokens), torch.from_numpy(coarse_tokens)\n",
312
+ "\n",
313
+ " def __len__(self):\n",
314
+ " return len(self.audiopaths_and_text)\n",
315
+ "\n",
316
+ "\n",
317
+ "class TtsCollater():\n",
318
+ " def __init__(self):\n",
319
+ " pass\n",
320
+ " def __call__(self, batch):\n",
321
+ " max_semantic_len = MAX_SEMANTIC_LEN\n",
322
+ " max_coarse_len = MAX_COARSE_LEN\n",
323
+ " semantic_to_coarse_ratio = COARSE_RATE_HZ / SEMANTIC_RATE_HZ * N_COARSE_CODEBOOKS\n",
324
+ " semantic_tokens = []\n",
325
+ " coarse_tokens = []\n",
326
+ "\n",
327
+ " for b in batch:\n",
328
+ " semantic_tokens_, coarse_tokens_ = b\n",
329
+ " start_idx = None\n",
330
+ " if len(semantic_tokens_) > max_semantic_len:\n",
331
+ " start_idx = np.random.randint(0, len(semantic_tokens_) - max_semantic_len + 1)\n",
332
+ " semantic_tokens_ = semantic_tokens_[start_idx:start_idx+max_semantic_len]\n",
333
+ " semantic_tokens_ = F.pad(semantic_tokens_, (0, max_semantic_len-len(semantic_tokens_)), value=COARSE_SEMANTIC_PAD_TOKEN)\n",
334
+ " semantic_tokens_ = torch.cat([semantic_tokens_, torch.tensor([COARSE_INFER_TOKEN])])\n",
335
+ " semantic_tokens.append(semantic_tokens_)\n",
336
+ "\n",
337
+ " if start_idx is not None:\n",
338
+ " start_idx_coarse = int(start_idx * semantic_to_coarse_ratio) \n",
339
+ " coarse_tokens_ = coarse_tokens_[start_idx_coarse:start_idx_coarse+max_coarse_len]\n",
340
+ " coarse_tokens_ = F.pad(coarse_tokens_, (0, max_coarse_len-len(coarse_tokens_)), value=COARSE_SEMANTIC_PAD_TOKEN)\n",
341
+ " coarse_tokens.append(coarse_tokens_)\n",
342
+ "\n",
343
+ " return {\n",
344
+ " 'semantic_tokens': torch.stack(semantic_tokens).contiguous(),\n",
345
+ " 'coarse_tokens': torch.stack(coarse_tokens).contiguous()\n",
346
+ " }\n",
347
+ " \n",
348
+ "\n",
349
+ "accelerator = Accelerator(\n",
350
+ " gradient_accumulation_steps=grad_accum,\n",
351
+ " mixed_precision=mixed_precision,\n",
352
+ " log_with=log_with,\n",
353
+ " logging_dir=logging_dir,\n",
354
+ ")\n",
355
+ "device = accelerator.device\n",
356
+ "\n",
357
+ "os.makedirs(output_dir, exist_ok=True)\n",
358
+ "\n",
359
+ "set_seed(seed)"
360
+ ]
361
+ },
362
+ {
363
+ "attachments": {},
364
+ "cell_type": "markdown",
365
+ "metadata": {},
366
+ "source": [
367
+ "# Setup Dataset (only need to do this once)"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": null,
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": [
376
+ "# max_duration_sec = 15.12 # the maximum allowed duration in seconds\n",
377
+ "\n",
378
+ "# path = dataset_path\n",
379
+ "\n",
380
+ "# # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n",
381
+ "# from hubert.hubert_manager import HuBERTManager\n",
382
+ "# hubert_manager = HuBERTManager()\n",
383
+ "# from hubert.pre_kmeans_hubert import CustomHubert\n",
384
+ "# from hubert.customtokenizer import CustomTokenizer\n",
385
+ "# hubert_manager.make_sure_hubert_installed()\n",
386
+ "# hubert_manager.make_sure_tokenizer_installed()\n",
387
+ "\n",
388
+ "# # Load the HuBERT model\n",
389
+ "# hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n",
390
+ "# hubert_model.eval()\n",
391
+ "# for param in hubert_model.parameters():\n",
392
+ "# param.requires_grad = False\n",
393
+ "\n",
394
+ "# # Load the CustomTokenizer model\n",
395
+ "# hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device) # Automatically uses the right layers\n",
396
+ "\n",
397
+ "# from bark.generation import load_codec_model\n",
398
+ "# codec_model = load_codec_model(use_gpu=True)\n",
399
+ "# codec_model.eval()\n",
400
+ "# for param in codec_model.parameters():\n",
401
+ "# param.requires_grad = False\n",
402
+ "\n",
403
+ "\n",
404
+ "# def get_duration(wav, sr):\n",
405
+ "# return wav.shape[1] / sr\n",
406
+ "\n",
407
+ "# valid_lines_train = []\n",
408
+ "# # convert wavs to semantic tokens\n",
409
+ "# for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n",
410
+ "# wav, sr = torchaudio.load(wav_path)\n",
411
+ "# if not get_duration(wav, sr) > max_duration_sec:\n",
412
+ "# valid_lines_train.append((wav_path, txt))\n",
413
+ "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
414
+ "\n",
415
+ "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
416
+ "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
417
+ "\n",
418
+ "# # save semantic tokens\n",
419
+ "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
420
+ "# semantic_tokens = semantic_tokens.cpu().numpy()\n",
421
+ "\n",
422
+ "# # Extract discrete codes from EnCodec\n",
423
+ "# with torch.no_grad():\n",
424
+ "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
425
+ "# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n",
426
+ "\n",
427
+ "# # move codes to cpu\n",
428
+ "# codes = codes.cpu().numpy()\n",
429
+ "\n",
430
+ "# # save tokens\n",
431
+ "# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
432
+ "\n",
433
+ "# # rewrite train.txt with valid lines\n",
434
+ "# with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n",
435
+ "# for wav_path, txt in valid_lines_train:\n",
436
+ "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
437
+ "# f.write(f'{wav_path}|{txt}\\n')\n",
438
+ "\n",
439
+ "# valid_lines_valid = []\n",
440
+ "# for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n",
441
+ "# wav, sr = torchaudio.load(wav_path)\n",
442
+ "# if not get_duration(wav, sr) > max_duration_sec:\n",
443
+ "# valid_lines_valid.append((wav_path, txt))\n",
444
+ "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
445
+ "\n",
446
+ "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
447
+ "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
448
+ "\n",
449
+ "# # save semantic tokens\n",
450
+ "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
451
+ "# semantic_tokens = semantic_tokens.cpu().numpy()\n",
452
+ " \n",
453
+ "# # Extract discrete codes from EnCodec\n",
454
+ "# with torch.no_grad():\n",
455
+ "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
456
+ "# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n",
457
+ "\n",
458
+ "# # move codes to cpu\n",
459
+ "# codes = codes.cpu().numpy()\n",
460
+ "\n",
461
+ "# # save tokens\n",
462
+ "# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
463
+ "\n",
464
+ "# # rewrite valid.txt with valid lines\n",
465
+ "# with open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n",
466
+ "# for wav_path, txt in valid_lines_valid:\n",
467
+ "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
468
+ "# f.write(f'{wav_path}|{txt}\\n')\n",
469
+ "\n",
470
+ "# del hubert_model\n",
471
+ "# del hubert_tokenizer\n",
472
+ "# del codec_model\n",
473
+ "# gc.collect()\n",
474
+ "# torch.cuda.empty_cache()"
475
+ ]
476
+ },
477
+ {
478
+ "attachments": {},
479
+ "cell_type": "markdown",
480
+ "metadata": {},
481
+ "source": [
482
+ "# Setup"
483
+ ]
484
+ },
485
+ {
486
+ "cell_type": "code",
487
+ "execution_count": null,
488
+ "metadata": {},
489
+ "outputs": [],
490
+ "source": [
491
+ "model = _load_model(ckpt_path, device, use_small=False, model_type=model_type)"
492
+ ]
493
+ },
494
+ {
495
+ "cell_type": "code",
496
+ "execution_count": null,
497
+ "metadata": {},
498
+ "outputs": [],
499
+ "source": [
500
+ "if scale_lr:\n",
501
+ " learning_rate = (\n",
502
+ " learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n",
503
+ " )\n",
504
+ "\n",
505
+ "if use_8bit_adam:\n",
506
+ " try:\n",
507
+ " import bitsandbytes as bnb\n",
508
+ " except ImportError:\n",
509
+ " raise ImportError(\n",
510
+ " \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n",
511
+ " )\n",
512
+ "\n",
513
+ " optimizer_class = bnb.optim.AdamW8bit\n",
514
+ "else:\n",
515
+ " optimizer_class = torch.optim.AdamW"
516
+ ]
517
+ },
518
+ {
519
+ "cell_type": "code",
520
+ "execution_count": null,
521
+ "metadata": {},
522
+ "outputs": [],
523
+ "source": [
524
+ "quantization_config=BitsAndBytesConfig(\n",
525
+ " load_in_4bit=bits == 4,\n",
526
+ " load_in_8bit=bits == 8,\n",
527
+ " llm_int8_threshold=6.0,\n",
528
+ " llm_int8_has_fp16_weight=False,\n",
529
+ " bnb_4bit_compute_dtype=compute_dtype,\n",
530
+ " bnb_4bit_use_double_quant=double_quant,\n",
531
+ " bnb_4bit_quant_type=quant_type # {'fp4', 'nf4'}\n",
532
+ ")\n",
533
+ "\n",
534
+ "# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n",
535
+ "# if quantization_config.load_in_8bit:\n",
536
+ "# logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n",
537
+ "# elif quantization_config.load_in_4bit:\n",
538
+ "# logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n",
539
+ "\n",
540
+ "# # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n",
541
+ "# if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n",
542
+ "# modules_to_not_convert = [] # get_keys_to_not_convert(model)\n",
543
+ "# else:\n",
544
+ "# modules_to_not_convert = llm_int8_skip_modules\n",
545
+ "\n",
546
+ "# if not isinstance(modules_to_not_convert, list):\n",
547
+ "# modules_to_not_convert = [modules_to_not_convert]\n",
548
+ "\n",
549
+ "# modules_to_not_convert.extend(keep_in_fp32_modules)\n",
550
+ "\n",
551
+ "# supports_4bit = version.parse(importlib_metadata.version(\"bitsandbytes\")) >= version.parse(\"0.39.0\")\n",
552
+ "\n",
553
+ "# if quantization_config.load_in_4bit and not supports_4bit:\n",
554
+ "# raise ValueError(\n",
555
+ "# \"You have a version of `bitsandbytes` that is not compatible with 4bit inference and training\"\n",
556
+ "# \" make sure you have the latest version of `bitsandbytes` installed\"\n",
557
+ "# )\n",
558
+ " \n",
559
+ "# if len(modules_to_not_convert) == 0:\n",
560
+ "# modules_to_not_convert = None\n",
561
+ "\n",
562
+ "# model = replace_with_bnb_linear(\n",
563
+ "# model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config\n",
564
+ "# )\n",
565
+ "\n",
566
+ "# # training in 8-bit is only available in 0.37.0+\n",
567
+ "# model._is_kbit_training_enabled = version.parse(\n",
568
+ "# importlib_metadata.version(\"bitsandbytes\")\n",
569
+ "# ) >= version.parse(\"0.37.0\")\n",
570
+ "\n",
571
+ "# model.config.quantization_config = quantization_config"
572
+ ]
573
+ },
574
+ {
575
+ "cell_type": "code",
576
+ "execution_count": null,
577
+ "metadata": {},
578
+ "outputs": [],
579
+ "source": [
580
+ "if bits == 4:\n",
581
+ " from accelerate.utils import CustomDtype\n",
582
+ " target_dtype = CustomDtype.INT4\n",
583
+ "elif bits == 8:\n",
584
+ " target_dtype = torch.int8\n",
585
+ "\n",
586
+ "if lora_dim > 0:\n",
587
+ " for param in model.parameters():\n",
588
+ " if param.ndim == 1:\n",
589
+ " # cast the small parameters (e.g. layernorm) to fp32 for stability\n",
590
+ " param.data = param.data.to(torch.float32)\n",
591
+ " \n",
592
+ " class CastOutputToFloat(nn.Sequential):\n",
593
+ " def forward(self, x):\n",
594
+ " return super().forward(x).to(torch.float32)\n",
595
+ "\n",
596
+ " model.lm_head = CastOutputToFloat(model.lm_head)\n",
597
+ "\n",
598
+ " model = convert_linear_layer_to_lora(model, lora_module_name,\n",
599
+ " lora_dim=lora_dim, lora_scaling=lora_scaling,\n",
600
+ " lora_dropout=lora_dropout)\n",
601
+ " if optimize_lora_params_only:\n",
602
+ " model = only_optimize_lora_parameters(model)"
603
+ ]
604
+ },
605
+ {
606
+ "cell_type": "code",
607
+ "execution_count": null,
608
+ "metadata": {},
609
+ "outputs": [],
610
+ "source": [
611
+ "params_to_optimize = (\n",
612
+ " param for param in model.parameters() if param.requires_grad\n",
613
+ " )\n",
614
+ "\n",
615
+ "optimizer = optimizer_class(\n",
616
+ " params_to_optimize,\n",
617
+ " lr=learning_rate,\n",
618
+ " betas=(adam_beta1, adam_beta2),\n",
619
+ " weight_decay=weight_decay,\n",
620
+ " eps=adam_epsilon,\n",
621
+ ")"
622
+ ]
623
+ },
624
+ {
625
+ "cell_type": "code",
626
+ "execution_count": null,
627
+ "metadata": {},
628
+ "outputs": [],
629
+ "source": [
630
+ "opt_train = {\n",
631
+ " 'path': dataset_path,\n",
632
+ " 'mode': 'train',\n",
633
+ "}\n",
634
+ "\n",
635
+ "opt_val = {\n",
636
+ " 'path': dataset_path,\n",
637
+ " 'mode': 'valid',\n",
638
+ "}\n",
639
+ "\n",
640
+ "train_dataset = TtsDataset(opt_train)\n",
641
+ "validation_dataset = TtsDataset(opt_val)\n",
642
+ "\n",
643
+ "train_dataloader = torch.utils.data.DataLoader(\n",
644
+ " train_dataset,\n",
645
+ " batch_size=train_batch_size,\n",
646
+ " collate_fn=TtsCollater(),\n",
647
+ ")\n",
648
+ "\n",
649
+ "validation_dataloader = torch.utils.data.DataLoader(\n",
650
+ " validation_dataset,\n",
651
+ " batch_size=eval_batch_size,\n",
652
+ " collate_fn=TtsCollater(),\n",
653
+ ")\n",
654
+ "\n",
655
+ "criterion = torch.nn.CrossEntropyLoss(ignore_index=COARSE_SEMANTIC_PAD_TOKEN)\n",
656
+ "\n",
657
+ "# Scheduler and math around the number of training steps.\n",
658
+ "overrode_max_train_steps = False\n",
659
+ "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
660
+ "if max_train_steps is None:\n",
661
+ " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
662
+ " overrode_max_train_steps = True\n",
663
+ "\n",
664
+ "lr_scheduler = get_scheduler(\n",
665
+ " lr_scheduler_type,\n",
666
+ " optimizer=optimizer,\n",
667
+ " num_warmup_steps=lr_warmup_steps * grad_accum,\n",
668
+ " num_training_steps=max_train_steps * grad_accum,\n",
669
+ ")"
670
+ ]
671
+ },
672
+ {
673
+ "cell_type": "code",
674
+ "execution_count": null,
675
+ "metadata": {},
676
+ "outputs": [],
677
+ "source": [
678
+ "model, optimizer, train_dataloader, validation_dataloader, lr_scheduler = accelerator.prepare(\n",
679
+ " model, optimizer, train_dataloader, validation_dataloader, lr_scheduler\n",
680
+ ")\n",
681
+ "accelerator.register_for_checkpointing(lr_scheduler)\n",
682
+ "\n",
683
+ "weight_dtype = torch.float32\n",
684
+ "if accelerator.mixed_precision == \"fp16\":\n",
685
+ " weight_dtype = torch.float16\n",
686
+ "elif accelerator.mixed_precision == \"bf16\":\n",
687
+ " weight_dtype = torch.bfloat16"
688
+ ]
689
+ },
690
+ {
691
+ "cell_type": "code",
692
+ "execution_count": null,
693
+ "metadata": {},
694
+ "outputs": [],
695
+ "source": [
696
+ "# We need to recalculate our total training steps as the size of the training dataloader may have changed.\n",
697
+ "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
698
+ "if overrode_max_train_steps:\n",
699
+ " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
700
+ "# Afterwards we recalculate our number of training epochs\n",
701
+ "num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n",
702
+ "\n",
703
+ "# We need to initialize the trackers we use, and also store our configuration.\n",
704
+ "# The trackers initializes automatically on the main process.\n",
705
+ "if accelerator.is_main_process:\n",
706
+ " accelerator.init_trackers(\"bark_coarse\", config={})\n",
707
+ "\n",
708
+ "# Train!\n",
709
+ "total_batch_size = train_batch_size * accelerator.num_processes * grad_accum\n",
710
+ "logger.info(\"***** Running training *****\")\n",
711
+ "logger.info(f\" Num examples = {len(train_dataset)}\")\n",
712
+ "logger.info(f\" Num batches each epoch = {len(train_dataloader)}\")\n",
713
+ "logger.info(f\" Num Epochs = {num_train_epochs}\")\n",
714
+ "logger.info(f\" Instantaneous batch size per device = {train_batch_size}\")\n",
715
+ "logger.info(f\" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n",
716
+ "logger.info(f\" Gradient Accumulation steps = {grad_accum}\")\n",
717
+ "logger.info(f\" Total optimization steps = {max_train_steps}\")\n",
718
+ "global_step = 0\n",
719
+ "first_epoch = 0\n",
720
+ "\n",
721
+ "if resume_from_checkpoint:\n",
722
+ " if resume_from_checkpoint != \"latest\":\n",
723
+ " path = os.path.basename(resume_from_checkpoint)\n",
724
+ " else:\n",
725
+ " # Get the most recent checkpoint\n",
726
+ " dirs = os.listdir(output_dir)\n",
727
+ " dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n",
728
+ " dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n",
729
+ " path = dirs[-1]\n",
730
+ " accelerator.print(f\"Resuming from checkpoint {path}\")\n",
731
+ " accelerator.load_state(os.path.join(output_dir, path))\n",
732
+ " global_step = int(path.split(\"-\")[1])\n",
733
+ "\n",
734
+ " resume_global_step = global_step * grad_accum\n",
735
+ " first_epoch = resume_global_step // num_update_steps_per_epoch\n",
736
+ " resume_step = resume_global_step % num_update_steps_per_epoch\n"
737
+ ]
738
+ },
739
+ {
740
+ "cell_type": "code",
741
+ "execution_count": null,
742
+ "metadata": {},
743
+ "outputs": [],
744
+ "source": [
745
+ "if accelerator.is_main_process:\n",
746
+ " model.eval()\n",
747
+ " validation_loss = 0.0\n",
748
+ " num_batches = 0\n",
749
+ " num_samples = 0\n",
750
+ " with torch.no_grad():\n",
751
+ " for val_step, val_batch in enumerate(validation_dataloader):\n",
752
+ " # Similar to training, process the validation batch\n",
753
+ " val_targets = val_batch['coarse_tokens'][:, 1:].contiguous()\n",
754
+ " val_coarse_inputs = val_batch['coarse_tokens'][:, :-1]\n",
755
+ " val_inputs = torch.cat([val_batch['semantic_tokens'], val_coarse_inputs], dim=1)\n",
756
+ "\n",
757
+ " # Forward pass for validation\n",
758
+ " val_logits = model(val_inputs, training=True)\n",
759
+ " val_coarse_logits = val_logits[:, val_batch['semantic_tokens'].size(1):].contiguous()\n",
760
+ "\n",
761
+ " # Calculate the validation loss\n",
762
+ " val_loss = criterion(val_coarse_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n",
763
+ " validation_loss += val_loss.item()\n",
764
+ " num_batches += 1\n",
765
+ " num_samples += val_batch['semantic_tokens'].size(0)\n",
766
+ "\n",
767
+ " average_validation_loss = validation_loss / num_batches\n",
768
+ " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
769
+ " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
770
+ ]
771
+ },
772
+ {
773
+ "attachments": {},
774
+ "cell_type": "markdown",
775
+ "metadata": {},
776
+ "source": [
777
+ "# Training"
778
+ ]
779
+ },
780
+ {
781
+ "cell_type": "code",
782
+ "execution_count": null,
783
+ "metadata": {},
784
+ "outputs": [],
785
+ "source": [
786
+ "# Only show the progress bar once on each machine.\n",
787
+ "progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)\n",
788
+ "progress_bar.set_description(\"Steps\")\n",
789
+ "\n",
790
+ "for epoch in range(first_epoch, num_train_epochs):\n",
791
+ " model.train()\n",
792
+ " for step, batch in enumerate(train_dataloader):\n",
793
+ " # Skip steps until we reach the resumed step\n",
794
+ " if resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n",
795
+ " if step % grad_accum == 0:\n",
796
+ " progress_bar.update(1)\n",
797
+ " continue\n",
798
+ "\n",
799
+ " with accelerator.accumulate(model):\n",
800
+ " targets = batch['coarse_tokens'][:, 1:].contiguous()\n",
801
+ " \n",
802
+ " # Remove the last coarse token from the inputs since there is no target for it.\n",
803
+ " coarse_inputs = batch['coarse_tokens'][:, :-1]\n",
804
+ "\n",
805
+ " # Combine the semantic tokens and coarse tokens and feed them into the model.\n",
806
+ " inputs = torch.cat([batch['semantic_tokens'], coarse_inputs], dim=1)\n",
807
+ " logits = model(inputs, training=True)\n",
808
+ "\n",
809
+ " # We're only interested in the logits for the coarse tokens, so we ignore the logits for the input text tokens.\n",
810
+ " coarse_logits = logits[:, batch['semantic_tokens'].size(1):].contiguous()\n",
811
+ "\n",
812
+ " # Compute the loss.\n",
813
+ " loss = criterion(coarse_logits.view(-1, model.config.output_vocab_size), targets.view(-1))\n",
814
+ "\n",
815
+ " if semantic_cross_entropy_loss_weight > 0 and semantic_cross_entropy_loss_weight is not None:\n",
816
+ " semantic_logits = logits[:, :batch['semantic_tokens'].size(1)].contiguous()\n",
817
+ " semantic_loss = criterion(\n",
818
+ " semantic_logits.view(-1, model.config.input_vocab_size),\n",
819
+ " batch['semantic_tokens'].view(-1),\n",
820
+ " )\n",
821
+ " num_semantic_logits = semantic_logits.size(1)\n",
822
+ " num_coarse_logits = coarse_logits.size(1)\n",
823
+ " loss = (\n",
824
+ " semantic_loss * num_semantic_logits * semantic_cross_entropy_loss_weight +\n",
825
+ " loss * num_coarse_logits\n",
826
+ " ) / (num_semantic_logits + num_coarse_logits)\n",
827
+ "\n",
828
+ " accelerator.backward(loss)\n",
829
+ " if accelerator.sync_gradients:\n",
830
+ " params_to_clip = (\n",
831
+ " param for param in model.parameters() if param.requires_grad\n",
832
+ " )\n",
833
+ " accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)\n",
834
+ " optimizer.step()\n",
835
+ " lr_scheduler.step()\n",
836
+ " optimizer.zero_grad()\n",
837
+ "\n",
838
+ " # Checks if the accelerator has performed an optimization step behind the scenes\n",
839
+ " if accelerator.sync_gradients:\n",
840
+ " progress_bar.update(1)\n",
841
+ " global_step += 1\n",
842
+ "\n",
843
+ " if global_step % checkpointing_steps == 0:\n",
844
+ " if accelerator.is_main_process:\n",
845
+ " save_path = os.path.join(output_dir, f\"checkpoint-{global_step}\")\n",
846
+ " accelerator.save_state(save_path)\n",
847
+ " logger.info(f\"Saved state to {save_path}\")\n",
848
+ "\n",
849
+ " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n",
850
+ " progress_bar.set_postfix(**logs)\n",
851
+ " accelerator.log(logs, step=global_step)\n",
852
+ "\n",
853
+ " if global_step >= max_train_steps:\n",
854
+ " break\n",
855
+ " \n",
856
+ " accelerator.wait_for_everyone()\n",
857
+ "\n",
858
+ "if accelerator.is_main_process:\n",
859
+ " if lora_dim > 0:\n",
860
+ " model = convert_lora_to_linear_layer(model)\n",
861
+ " # save model\n",
862
+ " accelerator.save(model.state_dict(), os.path.join(output_dir, \"pytorch_model.bin\"))\n",
863
+ " \n",
864
+ " config = model.config.__dict__\n",
865
+ " # save config\n",
866
+ " with open(os.path.join(output_dir, \"config.json\"), \"w\") as f:\n",
867
+ " json.dump(config, f, indent=2)\n",
868
+ "\n",
869
+ "accelerator.end_training()"
870
+ ]
871
+ },
872
+ {
873
+ "attachments": {},
874
+ "cell_type": "markdown",
875
+ "metadata": {},
876
+ "source": [
877
+ "# Validation"
878
+ ]
879
+ },
880
+ {
881
+ "cell_type": "code",
882
+ "execution_count": null,
883
+ "metadata": {},
884
+ "outputs": [],
885
+ "source": [
886
+ "if accelerator.is_main_process:\n",
887
+ " model.eval()\n",
888
+ " validation_loss = 0.0\n",
889
+ " num_batches = 0\n",
890
+ " num_samples = 0\n",
891
+ " with torch.no_grad():\n",
892
+ " for val_step, val_batch in enumerate(validation_dataloader):\n",
893
+ " # Similar to training, process the validation batch\n",
894
+ " val_targets = val_batch['coarse_tokens'][:, 1:].contiguous()\n",
895
+ " val_coarse_inputs = val_batch['coarse_tokens'][:, :-1]\n",
896
+ " val_inputs = torch.cat([val_batch['semantic_tokens'], val_coarse_inputs], dim=1)\n",
897
+ "\n",
898
+ " # Forward pass for validation\n",
899
+ " val_logits = model(val_inputs, training=True)\n",
900
+ " val_coarse_logits = val_logits[:, val_batch['semantic_tokens'].size(1):].contiguous()\n",
901
+ "\n",
902
+ " # Calculate the validation loss\n",
903
+ " val_loss = criterion(val_coarse_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n",
904
+ " validation_loss += val_loss.item()\n",
905
+ " num_batches += 1\n",
906
+ " num_samples += val_batch['semantic_tokens'].size(0)\n",
907
+ "\n",
908
+ " average_validation_loss = validation_loss / num_batches\n",
909
+ " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
910
+ " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
911
+ ]
912
+ }
913
+ ],
914
+ "metadata": {
915
+ "kernelspec": {
916
+ "display_name": "Python 3",
917
+ "language": "python",
918
+ "name": "python3"
919
+ },
920
+ "language_info": {
921
+ "codemirror_mode": {
922
+ "name": "ipython",
923
+ "version": 3
924
+ },
925
+ "file_extension": ".py",
926
+ "mimetype": "text/x-python",
927
+ "name": "python",
928
+ "nbconvert_exporter": "python",
929
+ "pygments_lexer": "ipython3",
930
+ "version": "3.10.8"
931
+ },
932
+ "orig_nbformat": 4
933
+ },
934
+ "nbformat": 4,
935
+ "nbformat_minor": 2
936
+ }
train_fine.ipynb ADDED
@@ -0,0 +1,919 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Imports"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "metadata": {},
15
+ "outputs": [],
16
+ "source": [
17
+ "import torch\n",
18
+ "import torch.nn as nn\n",
19
+ "import os\n",
20
+ "import re\n",
21
+ "import gc\n",
22
+ "import json\n",
23
+ "import math\n",
24
+ "import hashlib\n",
25
+ "import numpy as np\n",
26
+ "import logging\n",
27
+ "import torchaudio\n",
28
+ "from tqdm.auto import tqdm\n",
29
+ "import torch.nn.functional as F\n",
30
+ "from encodec.utils import convert_audio\n",
31
+ "from accelerate import Accelerator\n",
32
+ "from accelerate.utils import set_seed\n",
33
+ "from transformers import BertTokenizer\n",
34
+ "from huggingface_hub import hf_hub_download\n",
35
+ "from packaging import version\n",
36
+ "from diffusers.optimization import get_scheduler\n",
37
+ "\n",
38
+ "from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n",
39
+ "from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n",
40
+ "from bark.model import GPTConfig, GPT\n",
41
+ "from bark.model_fine import FineGPT, FineGPTConfig"
42
+ ]
43
+ },
44
+ {
45
+ "attachments": {},
46
+ "cell_type": "markdown",
47
+ "metadata": {},
48
+ "source": [
49
+ "# Training Args"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "train_batch_size = 8\n",
59
+ "eval_batch_size = 8\n",
60
+ "grad_accum = 2\n",
61
+ "ckpt_path = 'models/fine_2.pt'\n",
62
+ "model_type = \"fine\"\n",
63
+ "dataset_path = 'datasets/joe_biden_state_of_union/'\n",
64
+ "logging_dir = 'logs/'\n",
65
+ "log_with = 'wandb'\n",
66
+ "hubert_path = 'data/models/hubert/hubert.pt'\n",
67
+ "hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n",
68
+ "\n",
69
+ "output_dir = 'fine_output/'\n",
70
+ "resume_from_checkpoint = None\n",
71
+ "\n",
72
+ "checkpointing_steps = 1000\n",
73
+ "\n",
74
+ "mixed_precision = 'bf16'\n",
75
+ "bits = 16 #4 4 and 8 bit are a work in progress\n",
76
+ "compute_dtype = torch.bfloat16\n",
77
+ "double_quant = True\n",
78
+ "quant_type = 'nf4'\n",
79
+ "\n",
80
+ "lora_dim = 64\n",
81
+ "lora_scaling = 1\n",
82
+ "lora_dropout = 0.1\n",
83
+ "lora_module_name = 'transformer.h'\n",
84
+ "optimize_lora_params_only = False\n",
85
+ "\n",
86
+ "learning_rate = 1e-4\n",
87
+ "scale_lr = False\n",
88
+ "use_8bit_adam = False\n",
89
+ "adam_beta1 = 0.9\n",
90
+ "adam_beta2 = 0.999\n",
91
+ "adam_epsilon = 1e-8\n",
92
+ "weight_decay = 0.01\n",
93
+ "\n",
94
+ "llm_int8_skip_modules = None\n",
95
+ "keep_in_fp32_modules = ['lm_head']\n",
96
+ "\n",
97
+ "lr_scheduler_type = 'linear'\n",
98
+ "lr_warmup_steps = 60\n",
99
+ "num_train_epochs = 5\n",
100
+ "max_train_steps = None\n",
101
+ "max_grad_norm = 1.0\n",
102
+ "\n",
103
+ "seed = 741"
104
+ ]
105
+ },
106
+ {
107
+ "attachments": {},
108
+ "cell_type": "markdown",
109
+ "metadata": {},
110
+ "source": [
111
+ "# Define Functions"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "CONTEXT_WINDOW_SIZE = 1024\n",
121
+ "\n",
122
+ "MAX_SEMANTIC_LEN = 256\n",
123
+ "\n",
124
+ "SEMANTIC_RATE_HZ = 49.9\n",
125
+ "SEMANTIC_VOCAB_SIZE = 10_000\n",
126
+ "\n",
127
+ "TEXT_ENCODING_OFFSET = 10_048\n",
128
+ "SEMANTIC_PAD_TOKEN = 10_000\n",
129
+ "TEXT_PAD_TOKEN = 129_595\n",
130
+ "SEMANTIC_INFER_TOKEN = 129_599\n",
131
+ "\n",
132
+ "MAX_COARSE_LEN = 768\n",
133
+ "\n",
134
+ "SAMPLE_RATE = 24_000\n",
135
+ "CHANNELS = 1\n",
136
+ "\n",
137
+ "COARSE_SEMANTIC_PAD_TOKEN = 12_048\n",
138
+ "COARSE_INFER_TOKEN = 12_050\n",
139
+ "\n",
140
+ "CODEBOOK_SIZE = 1024\n",
141
+ "N_COARSE_CODEBOOKS = 2\n",
142
+ "N_FINE_CODEBOOKS = 8\n",
143
+ "COARSE_RATE_HZ = 75\n",
144
+ "\n",
145
+ "logger = logging.getLogger(__name__)\n",
146
+ "\n",
147
+ "\n",
148
+ "USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n",
149
+ "\n",
150
+ "default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n",
151
+ "CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n",
152
+ "\n",
153
+ "\n",
154
+ "def _clear_cuda_cache():\n",
155
+ " if torch.cuda.is_available():\n",
156
+ " torch.cuda.empty_cache()\n",
157
+ " torch.cuda.synchronize()\n",
158
+ "\n",
159
+ "\n",
160
+ "def _md5(fname):\n",
161
+ " hash_md5 = hashlib.md5()\n",
162
+ " with open(fname, \"rb\") as f:\n",
163
+ " for chunk in iter(lambda: f.read(4096), b\"\"):\n",
164
+ " hash_md5.update(chunk)\n",
165
+ " return hash_md5.hexdigest()\n",
166
+ "\n",
167
+ "\n",
168
+ "def _download(from_hf_path, file_name, to_local_path):\n",
169
+ " to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n",
170
+ " path = '/'.join(to_local_path.split(\"/\")[:-1])\n",
171
+ " os.makedirs(path, exist_ok=True)\n",
172
+ " hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n",
173
+ " os.replace(os.path.join(path, file_name), to_local_path)\n",
174
+ "\n",
175
+ "\n",
176
+ "def _tokenize(tokenizer, text):\n",
177
+ " return tokenizer.encode(text, add_special_tokens=False)\n",
178
+ "\n",
179
+ "\n",
180
+ "def _detokenize(tokenizer, enc_text):\n",
181
+ " return tokenizer.decode(enc_text)\n",
182
+ "\n",
183
+ "\n",
184
+ "def _normalize_whitespace(text):\n",
185
+ " return re.sub(r\"\\s+\", \" \", text).strip()\n",
186
+ "\n",
187
+ "\n",
188
+ "REMOTE_MODEL_PATHS = {\n",
189
+ " \"text_small\": {\n",
190
+ " \"repo_id\": \"suno/bark\",\n",
191
+ " \"file_name\": \"text.pt\",\n",
192
+ " \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n",
193
+ " },\n",
194
+ " \"coarse_small\": {\n",
195
+ " \"repo_id\": \"suno/bark\",\n",
196
+ " \"file_name\": \"coarse.pt\",\n",
197
+ " \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n",
198
+ " },\n",
199
+ " \"fine_small\": {\n",
200
+ " \"repo_id\": \"suno/bark\",\n",
201
+ " \"file_name\": \"fine.pt\",\n",
202
+ " \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n",
203
+ " },\n",
204
+ " \"text\": {\n",
205
+ " \"repo_id\": \"suno/bark\",\n",
206
+ " \"file_name\": \"text_2.pt\",\n",
207
+ " \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n",
208
+ " },\n",
209
+ " \"coarse\": {\n",
210
+ " \"repo_id\": \"suno/bark\",\n",
211
+ " \"file_name\": \"coarse_2.pt\",\n",
212
+ " \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n",
213
+ " },\n",
214
+ " \"fine\": {\n",
215
+ " \"repo_id\": \"suno/bark\",\n",
216
+ " \"file_name\": \"fine_2.pt\",\n",
217
+ " \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n",
218
+ " },\n",
219
+ "}\n",
220
+ "\n",
221
+ "\n",
222
+ "def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n",
223
+ " if model_type == \"text\":\n",
224
+ " ConfigClass = GPTConfig\n",
225
+ " ModelClass = GPT\n",
226
+ " elif model_type == \"coarse\":\n",
227
+ " ConfigClass = GPTConfig\n",
228
+ " ModelClass = GPT\n",
229
+ " elif model_type == \"fine\":\n",
230
+ " ConfigClass = FineGPTConfig\n",
231
+ " ModelClass = FineGPT\n",
232
+ " else:\n",
233
+ " raise NotImplementedError()\n",
234
+ " model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n",
235
+ " model_info = REMOTE_MODEL_PATHS[model_key]\n",
236
+ " if ckpt_path in [None, '']:\n",
237
+ " ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n",
238
+ " if not os.path.exists(ckpt_path):\n",
239
+ " logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n",
240
+ " _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n",
241
+ " checkpoint = torch.load(ckpt_path, map_location=device)\n",
242
+ " # this is a hack\n",
243
+ " model_args = checkpoint[\"model_args\"]\n",
244
+ " if \"input_vocab_size\" not in model_args:\n",
245
+ " model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n",
246
+ " model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n",
247
+ " del model_args[\"vocab_size\"]\n",
248
+ " gptconf = ConfigClass(**checkpoint[\"model_args\"])\n",
249
+ " model = ModelClass(gptconf)\n",
250
+ " state_dict = checkpoint[\"model\"]\n",
251
+ " # fixup checkpoint\n",
252
+ " unwanted_prefix = \"_orig_mod.\"\n",
253
+ " for k, v in list(state_dict.items()):\n",
254
+ " if k.startswith(unwanted_prefix):\n",
255
+ " state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)\n",
256
+ " extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n",
257
+ " extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n",
258
+ " missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n",
259
+ " missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n",
260
+ " if len(extra_keys) != 0:\n",
261
+ " raise ValueError(f\"extra keys found: {extra_keys}\")\n",
262
+ " if len(missing_keys) != 0:\n",
263
+ " raise ValueError(f\"missing keys: {missing_keys}\")\n",
264
+ " model.load_state_dict(state_dict, strict=False)\n",
265
+ " n_params = model.get_num_params()\n",
266
+ " val_loss = checkpoint[\"best_val_loss\"].item()\n",
267
+ " print(f\"Loaded {model_type} model with {n_params} params, val_loss={val_loss:.4f}.\")\n",
268
+ " del checkpoint, state_dict\n",
269
+ " _clear_cuda_cache()\n",
270
+ " if model_type == \"text\":\n",
271
+ " tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")\n",
272
+ " return model, tokenizer\n",
273
+ " return model\n",
274
+ "\n",
275
+ "\n",
276
+ "def _flatten_codebooks(arr, offset_size=CODEBOOK_SIZE):\n",
277
+ " assert len(arr.shape) == 2\n",
278
+ " arr = arr.copy()\n",
279
+ " if offset_size is not None:\n",
280
+ " for n in range(1, arr.shape[0]):\n",
281
+ " arr[n, :] += offset_size * n\n",
282
+ " flat_arr = arr.ravel(\"F\")\n",
283
+ " return flat_arr\n",
284
+ "\n",
285
+ "\n",
286
+ "def load_filepaths_and_text(filename, split=\"|\"):\n",
287
+ " with open(filename, encoding='utf-8', errors='ignore') as f:\n",
288
+ " filepaths_and_text = [line.strip().split(split) for line in f]\n",
289
+ " base = os.path.dirname(filename)\n",
290
+ " for j in range(len(filepaths_and_text)):\n",
291
+ " filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])\n",
292
+ " return filepaths_and_text\n",
293
+ "\n",
294
+ "\n",
295
+ "class TtsDataset(torch.utils.data.Dataset):\n",
296
+ " def __init__(self, opt):\n",
297
+ " self.path = os.path.dirname(opt['path'])\n",
298
+ " self.mode = opt['mode']\n",
299
+ " self.audiopaths_and_text = load_filepaths_and_text(os.path.join(opt['path'] , opt['mode'] + '.txt'))\n",
300
+ "\n",
301
+ " def __getitem__(self, index):\n",
302
+ " audiopath_and_text = self.audiopaths_and_text[index]\n",
303
+ " audiopath = audiopath_and_text[0]\n",
304
+ "\n",
305
+ " tokens = np.load(audiopath.replace('.wav', '.npz').replace('wavs', 'tokens'))\n",
306
+ " fine_tokens = tokens['fine']\n",
307
+ "\n",
308
+ " return torch.from_numpy(fine_tokens)\n",
309
+ "\n",
310
+ " def __len__(self):\n",
311
+ " return len(self.audiopaths_and_text)\n",
312
+ "\n",
313
+ "\n",
314
+ "class TtsCollater():\n",
315
+ " def __init__(self):\n",
316
+ " pass\n",
317
+ " def __call__(self, batch):\n",
318
+ " max_len = 1024\n",
319
+ " fine_tokens = []\n",
320
+ "\n",
321
+ " for fine_tokens_ in batch:\n",
322
+ " if fine_tokens_.shape[1] > max_len:\n",
323
+ " start_idx = np.random.randint(0, fine_tokens_.shape[1] - max_len + 1)\n",
324
+ " fine_tokens_ = fine_tokens_[:, start_idx : start_idx + max_len]\n",
325
+ "\n",
326
+ " pad_size = max_len - fine_tokens_.shape[1]\n",
327
+ " fine_tokens_ = F.pad(fine_tokens_, (0, pad_size), value=CODEBOOK_SIZE)\n",
328
+ "\n",
329
+ " fine_tokens_ = fine_tokens_.T\n",
330
+ "\n",
331
+ " fine_tokens.append(fine_tokens_)\n",
332
+ "\n",
333
+ " return {'fine_tokens': torch.stack(fine_tokens).contiguous()}\n",
334
+ " \n",
335
+ "\n",
336
+ "accelerator = Accelerator(\n",
337
+ " gradient_accumulation_steps=grad_accum,\n",
338
+ " mixed_precision=mixed_precision,\n",
339
+ " log_with=log_with,\n",
340
+ " logging_dir=logging_dir,\n",
341
+ ")\n",
342
+ "device = accelerator.device\n",
343
+ "\n",
344
+ "os.makedirs(output_dir, exist_ok=True)\n",
345
+ "\n",
346
+ "set_seed(seed)"
347
+ ]
348
+ },
349
+ {
350
+ "attachments": {},
351
+ "cell_type": "markdown",
352
+ "metadata": {},
353
+ "source": [
354
+ "# Setup Dataset (only need to do this once)"
355
+ ]
356
+ },
357
+ {
358
+ "cell_type": "code",
359
+ "execution_count": null,
360
+ "metadata": {},
361
+ "outputs": [],
362
+ "source": [
363
+ "# max_duration_sec = 15.12 # the maximum allowed duration in seconds\n",
364
+ "\n",
365
+ "# path = dataset_path\n",
366
+ "\n",
367
+ "# # From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n",
368
+ "# from hubert.hubert_manager import HuBERTManager\n",
369
+ "# hubert_manager = HuBERTManager()\n",
370
+ "# from hubert.pre_kmeans_hubert import CustomHubert\n",
371
+ "# from hubert.customtokenizer import CustomTokenizer\n",
372
+ "# hubert_manager.make_sure_hubert_installed()\n",
373
+ "# hubert_manager.make_sure_tokenizer_installed()\n",
374
+ "\n",
375
+ "# # Load the HuBERT model\n",
376
+ "# hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n",
377
+ "# hubert_model.eval()\n",
378
+ "# for param in hubert_model.parameters():\n",
379
+ "# param.requires_grad = False\n",
380
+ "\n",
381
+ "# # Load the CustomTokenizer model\n",
382
+ "# hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device) # Automatically uses the right layers\n",
383
+ "\n",
384
+ "# from bark.generation import load_codec_model\n",
385
+ "# codec_model = load_codec_model(use_gpu=True)\n",
386
+ "# codec_model.eval()\n",
387
+ "# for param in codec_model.parameters():\n",
388
+ "# param.requires_grad = False\n",
389
+ "\n",
390
+ "\n",
391
+ "# def get_duration(wav, sr):\n",
392
+ "# return wav.shape[1] / sr\n",
393
+ "\n",
394
+ "# valid_lines_train = []\n",
395
+ "# # convert wavs to semantic tokens\n",
396
+ "# for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n",
397
+ "# wav, sr = torchaudio.load(wav_path)\n",
398
+ "# if not get_duration(wav, sr) > max_duration_sec:\n",
399
+ "# valid_lines_train.append((wav_path, txt))\n",
400
+ "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
401
+ "\n",
402
+ "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
403
+ "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
404
+ "\n",
405
+ "# # save semantic tokens\n",
406
+ "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
407
+ "# semantic_tokens = semantic_tokens.cpu().numpy()\n",
408
+ "\n",
409
+ "# # Extract discrete codes from EnCodec\n",
410
+ "# with torch.no_grad():\n",
411
+ "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
412
+ "# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n",
413
+ "\n",
414
+ "# # move codes to cpu\n",
415
+ "# codes = codes.cpu().numpy()\n",
416
+ "\n",
417
+ "# # save tokens\n",
418
+ "# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
419
+ "\n",
420
+ "# # rewrite train.txt with valid lines\n",
421
+ "# with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n",
422
+ "# for wav_path, txt in valid_lines_train:\n",
423
+ "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
424
+ "# f.write(f'{wav_path}|{txt}\\n')\n",
425
+ "\n",
426
+ "# valid_lines_valid = []\n",
427
+ "# for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n",
428
+ "# wav, sr = torchaudio.load(wav_path)\n",
429
+ "# if not get_duration(wav, sr) > max_duration_sec:\n",
430
+ "# valid_lines_valid.append((wav_path, txt))\n",
431
+ "# wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
432
+ "\n",
433
+ "# semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
434
+ "# semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
435
+ "\n",
436
+ "# # save semantic tokens\n",
437
+ "# os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
438
+ "# semantic_tokens = semantic_tokens.cpu().numpy()\n",
439
+ " \n",
440
+ "# # Extract discrete codes from EnCodec\n",
441
+ "# with torch.no_grad():\n",
442
+ "# encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
443
+ "# codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n",
444
+ "\n",
445
+ "# # move codes to cpu\n",
446
+ "# codes = codes.cpu().numpy()\n",
447
+ "\n",
448
+ "# # save tokens\n",
449
+ "# np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
450
+ "\n",
451
+ "# # rewrite valid.txt with valid lines\n",
452
+ "# with open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n",
453
+ "# for wav_path, txt in valid_lines_valid:\n",
454
+ "# wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
455
+ "# f.write(f'{wav_path}|{txt}\\n')\n",
456
+ "\n",
457
+ "# del hubert_model\n",
458
+ "# del hubert_tokenizer\n",
459
+ "# del codec_model\n",
460
+ "# gc.collect()\n",
461
+ "# torch.cuda.empty_cache()"
462
+ ]
463
+ },
464
+ {
465
+ "attachments": {},
466
+ "cell_type": "markdown",
467
+ "metadata": {},
468
+ "source": [
469
+ "# Setup"
470
+ ]
471
+ },
472
+ {
473
+ "cell_type": "code",
474
+ "execution_count": null,
475
+ "metadata": {},
476
+ "outputs": [],
477
+ "source": [
478
+ "model = _load_model(ckpt_path, device, use_small=False, model_type=model_type)"
479
+ ]
480
+ },
481
+ {
482
+ "cell_type": "code",
483
+ "execution_count": null,
484
+ "metadata": {},
485
+ "outputs": [],
486
+ "source": [
487
+ "if scale_lr:\n",
488
+ " learning_rate = (\n",
489
+ " learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n",
490
+ " )\n",
491
+ "\n",
492
+ "if use_8bit_adam:\n",
493
+ " try:\n",
494
+ " import bitsandbytes as bnb\n",
495
+ " except ImportError:\n",
496
+ " raise ImportError(\n",
497
+ " \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n",
498
+ " )\n",
499
+ "\n",
500
+ " optimizer_class = bnb.optim.AdamW8bit\n",
501
+ "else:\n",
502
+ " optimizer_class = torch.optim.AdamW"
503
+ ]
504
+ },
505
+ {
506
+ "cell_type": "code",
507
+ "execution_count": null,
508
+ "metadata": {},
509
+ "outputs": [],
510
+ "source": [
511
+ "quantization_config=BitsAndBytesConfig(\n",
512
+ " load_in_4bit=bits == 4,\n",
513
+ " load_in_8bit=bits == 8,\n",
514
+ " llm_int8_threshold=6.0,\n",
515
+ " llm_int8_has_fp16_weight=False,\n",
516
+ " bnb_4bit_compute_dtype=compute_dtype,\n",
517
+ " bnb_4bit_use_double_quant=double_quant,\n",
518
+ " bnb_4bit_quant_type=quant_type # {'fp4', 'nf4'}\n",
519
+ ")\n",
520
+ "\n",
521
+ "# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n",
522
+ "# if quantization_config.load_in_8bit:\n",
523
+ "# logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n",
524
+ "# elif quantization_config.load_in_4bit:\n",
525
+ "# logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n",
526
+ "\n",
527
+ "# # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n",
528
+ "# if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n",
529
+ "# modules_to_not_convert = [] # get_keys_to_not_convert(model)\n",
530
+ "# else:\n",
531
+ "# modules_to_not_convert = llm_int8_skip_modules\n",
532
+ "\n",
533
+ "# if not isinstance(modules_to_not_convert, list):\n",
534
+ "# modules_to_not_convert = [modules_to_not_convert]\n",
535
+ "\n",
536
+ "# modules_to_not_convert.extend(keep_in_fp32_modules)\n",
537
+ "\n",
538
+ "# supports_4bit = version.parse(importlib_metadata.version(\"bitsandbytes\")) >= version.parse(\"0.39.0\")\n",
539
+ "\n",
540
+ "# if quantization_config.load_in_4bit and not supports_4bit:\n",
541
+ "# raise ValueError(\n",
542
+ "# \"You have a version of `bitsandbytes` that is not compatible with 4bit inference and training\"\n",
543
+ "# \" make sure you have the latest version of `bitsandbytes` installed\"\n",
544
+ "# )\n",
545
+ " \n",
546
+ "# if len(modules_to_not_convert) == 0:\n",
547
+ "# modules_to_not_convert = None\n",
548
+ "\n",
549
+ "# model = replace_with_bnb_linear(\n",
550
+ "# model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config\n",
551
+ "# )\n",
552
+ "\n",
553
+ "# # training in 8-bit is only available in 0.37.0+\n",
554
+ "# model._is_kbit_training_enabled = version.parse(\n",
555
+ "# importlib_metadata.version(\"bitsandbytes\")\n",
556
+ "# ) >= version.parse(\"0.37.0\")\n",
557
+ "\n",
558
+ "# model.config.quantization_config = quantization_config"
559
+ ]
560
+ },
561
+ {
562
+ "cell_type": "code",
563
+ "execution_count": null,
564
+ "metadata": {},
565
+ "outputs": [],
566
+ "source": [
567
+ "if bits == 4:\n",
568
+ " from accelerate.utils import CustomDtype\n",
569
+ " target_dtype = CustomDtype.INT4\n",
570
+ "elif bits == 8:\n",
571
+ " target_dtype = torch.int8\n",
572
+ "\n",
573
+ "if lora_dim > 0:\n",
574
+ " for param in model.parameters():\n",
575
+ " if param.ndim == 1:\n",
576
+ " # cast the small parameters (e.g. layernorm) to fp32 for stability\n",
577
+ " param.data = param.data.to(torch.float32)\n",
578
+ " \n",
579
+ " class CastOutputToFloat(nn.Sequential):\n",
580
+ " def forward(self, x):\n",
581
+ " return super().forward(x).to(torch.float32)\n",
582
+ "\n",
583
+ " # model.lm_head = CastOutputToFloat(model.lm_head)\n",
584
+ " for i, lm_head in enumerate(model.lm_heads):\n",
585
+ " model.lm_heads[i] = CastOutputToFloat(lm_head)\n",
586
+ "\n",
587
+ " model = convert_linear_layer_to_lora(model, lora_module_name,\n",
588
+ " lora_dim=lora_dim, lora_scaling=lora_scaling,\n",
589
+ " lora_dropout=lora_dropout)\n",
590
+ " if optimize_lora_params_only:\n",
591
+ " model = only_optimize_lora_parameters(model)"
592
+ ]
593
+ },
594
+ {
595
+ "cell_type": "code",
596
+ "execution_count": null,
597
+ "metadata": {},
598
+ "outputs": [],
599
+ "source": [
600
+ "params_to_optimize = (\n",
601
+ " param for param in model.parameters() if param.requires_grad\n",
602
+ " )\n",
603
+ "\n",
604
+ "optimizer = optimizer_class(\n",
605
+ " params_to_optimize,\n",
606
+ " lr=learning_rate,\n",
607
+ " betas=(adam_beta1, adam_beta2),\n",
608
+ " weight_decay=weight_decay,\n",
609
+ " eps=adam_epsilon,\n",
610
+ ")"
611
+ ]
612
+ },
613
+ {
614
+ "cell_type": "code",
615
+ "execution_count": null,
616
+ "metadata": {},
617
+ "outputs": [],
618
+ "source": [
619
+ "opt_train = {\n",
620
+ " 'path': dataset_path,\n",
621
+ " 'mode': 'train',\n",
622
+ "}\n",
623
+ "\n",
624
+ "opt_val = {\n",
625
+ " 'path': dataset_path,\n",
626
+ " 'mode': 'valid',\n",
627
+ "}\n",
628
+ "\n",
629
+ "train_dataset = TtsDataset(opt_train)\n",
630
+ "validation_dataset = TtsDataset(opt_val)\n",
631
+ "\n",
632
+ "train_dataloader = torch.utils.data.DataLoader(\n",
633
+ " train_dataset,\n",
634
+ " batch_size=train_batch_size,\n",
635
+ " collate_fn=TtsCollater(),\n",
636
+ ")\n",
637
+ "\n",
638
+ "validation_dataloader = torch.utils.data.DataLoader(\n",
639
+ " validation_dataset,\n",
640
+ " batch_size=eval_batch_size,\n",
641
+ " collate_fn=TtsCollater(),\n",
642
+ ")\n",
643
+ "\n",
644
+ "criterion = torch.nn.CrossEntropyLoss(ignore_index=COARSE_SEMANTIC_PAD_TOKEN)\n",
645
+ "\n",
646
+ "# Scheduler and math around the number of training steps.\n",
647
+ "overrode_max_train_steps = False\n",
648
+ "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
649
+ "if max_train_steps is None:\n",
650
+ " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
651
+ " overrode_max_train_steps = True\n",
652
+ "\n",
653
+ "lr_scheduler = get_scheduler(\n",
654
+ " lr_scheduler_type,\n",
655
+ " optimizer=optimizer,\n",
656
+ " num_warmup_steps=lr_warmup_steps * grad_accum,\n",
657
+ " num_training_steps=max_train_steps * grad_accum,\n",
658
+ ")"
659
+ ]
660
+ },
661
+ {
662
+ "cell_type": "code",
663
+ "execution_count": null,
664
+ "metadata": {},
665
+ "outputs": [],
666
+ "source": [
667
+ "model, optimizer, train_dataloader, validation_dataloader, lr_scheduler = accelerator.prepare(\n",
668
+ " model, optimizer, train_dataloader, validation_dataloader, lr_scheduler\n",
669
+ ")\n",
670
+ "accelerator.register_for_checkpointing(lr_scheduler)\n",
671
+ "\n",
672
+ "weight_dtype = torch.float32\n",
673
+ "if accelerator.mixed_precision == \"fp16\":\n",
674
+ " weight_dtype = torch.float16\n",
675
+ "elif accelerator.mixed_precision == \"bf16\":\n",
676
+ " weight_dtype = torch.bfloat16"
677
+ ]
678
+ },
679
+ {
680
+ "cell_type": "code",
681
+ "execution_count": null,
682
+ "metadata": {},
683
+ "outputs": [],
684
+ "source": [
685
+ "# We need to recalculate our total training steps as the size of the training dataloader may have changed.\n",
686
+ "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
687
+ "if overrode_max_train_steps:\n",
688
+ " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
689
+ "# Afterwards we recalculate our number of training epochs\n",
690
+ "num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n",
691
+ "\n",
692
+ "# We need to initialize the trackers we use, and also store our configuration.\n",
693
+ "# The trackers initializes automatically on the main process.\n",
694
+ "if accelerator.is_main_process:\n",
695
+ " accelerator.init_trackers(\"bark_coarse\", config={})\n",
696
+ "\n",
697
+ "total_batch_size = train_batch_size * accelerator.num_processes * grad_accum\n",
698
+ "logger.info(\"***** Running training *****\")\n",
699
+ "logger.info(f\" Num examples = {len(train_dataset)}\")\n",
700
+ "logger.info(f\" Num batches each epoch = {len(train_dataloader)}\")\n",
701
+ "logger.info(f\" Num Epochs = {num_train_epochs}\")\n",
702
+ "logger.info(f\" Instantaneous batch size per device = {train_batch_size}\")\n",
703
+ "logger.info(f\" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n",
704
+ "logger.info(f\" Gradient Accumulation steps = {grad_accum}\")\n",
705
+ "logger.info(f\" Total optimization steps = {max_train_steps}\")\n",
706
+ "global_step = 0\n",
707
+ "first_epoch = 0\n",
708
+ "\n",
709
+ "if resume_from_checkpoint:\n",
710
+ " if resume_from_checkpoint != \"latest\":\n",
711
+ " path = os.path.basename(resume_from_checkpoint)\n",
712
+ " else:\n",
713
+ " # Get the most recent checkpoint\n",
714
+ " dirs = os.listdir(output_dir)\n",
715
+ " dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n",
716
+ " dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n",
717
+ " path = dirs[-1]\n",
718
+ " accelerator.print(f\"Resuming from checkpoint {path}\")\n",
719
+ " accelerator.load_state(os.path.join(output_dir, path))\n",
720
+ " global_step = int(path.split(\"-\")[1])\n",
721
+ "\n",
722
+ " resume_global_step = global_step * grad_accum\n",
723
+ " first_epoch = resume_global_step // num_update_steps_per_epoch\n",
724
+ " resume_step = resume_global_step % num_update_steps_per_epoch\n"
725
+ ]
726
+ },
727
+ {
728
+ "cell_type": "code",
729
+ "execution_count": null,
730
+ "metadata": {},
731
+ "outputs": [],
732
+ "source": [
733
+ "if accelerator.is_main_process:\n",
734
+ " model.eval()\n",
735
+ " validation_loss = 0.0\n",
736
+ " num_batches = 0\n",
737
+ " num_samples = 0\n",
738
+ " with torch.no_grad():\n",
739
+ " for val_step, val_batch in enumerate(validation_dataloader):\n",
740
+ " # Similar to training, process the validation batch\n",
741
+ " fine_targets_7 = val_batch['fine_tokens'][:, :, 6]\n",
742
+ " fine_tokens_input_7 = torch.cat([val_batch['fine_tokens'][:, :, :6], torch.zeros_like(val_batch['fine_tokens'][:, :, 6:])], dim=2)\n",
743
+ " fine_targets_8 = val_batch['fine_tokens'][:, :, 7]\n",
744
+ " fine_tokens_input_8 = torch.cat([val_batch['fine_tokens'][:, :, :7], torch.zeros_like(val_batch['fine_tokens'][:, :, 7:])], dim=2)\n",
745
+ "\n",
746
+ " # Forward pass for validation\n",
747
+ " logits_7 = model(6, fine_tokens_input_7)\n",
748
+ " logits_8 = model(7, fine_tokens_input_8)\n",
749
+ "\n",
750
+ " # Calculate the validation loss\n",
751
+ " loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), fine_targets_7.view(-1))\n",
752
+ " loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n",
753
+ "\n",
754
+ " loss = (loss_7 + loss_8) / 2\n",
755
+ " validation_loss += loss.item()\n",
756
+ " num_batches += 1\n",
757
+ " num_samples += val_batch['fine_tokens'].size(0)\n",
758
+ "\n",
759
+ " average_validation_loss = validation_loss / num_batches\n",
760
+ " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
761
+ " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
762
+ ]
763
+ },
764
+ {
765
+ "attachments": {},
766
+ "cell_type": "markdown",
767
+ "metadata": {},
768
+ "source": [
769
+ "# Training"
770
+ ]
771
+ },
772
+ {
773
+ "cell_type": "code",
774
+ "execution_count": null,
775
+ "metadata": {},
776
+ "outputs": [],
777
+ "source": [
778
+ "# Only show the progress bar once on each machine.\n",
779
+ "progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)\n",
780
+ "progress_bar.set_description(\"Steps\")\n",
781
+ "\n",
782
+ "for epoch in range(first_epoch, num_train_epochs):\n",
783
+ " model.train()\n",
784
+ " for step, batch in enumerate(train_dataloader):\n",
785
+ " # Skip steps until we reach the resumed step\n",
786
+ " if resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n",
787
+ " if step % grad_accum == 0:\n",
788
+ " progress_bar.update(1)\n",
789
+ " continue\n",
790
+ "\n",
791
+ " with accelerator.accumulate(model):\n",
792
+ " fine_targets_7 = batch['fine_tokens'][:, :, 6]\n",
793
+ " fine_tokens_input_7 = torch.cat([batch['fine_tokens'][:, :, :6], torch.zeros_like(batch['fine_tokens'][:, :, 6:])], dim=2)\n",
794
+ " fine_targets_8 = batch['fine_tokens'][:, :, 7]\n",
795
+ " fine_tokens_input_8 = torch.cat([batch['fine_tokens'][:, :, :7], torch.zeros_like(batch['fine_tokens'][:, :, 7:])], dim=2)\n",
796
+ "\n",
797
+ " # Forward pass\n",
798
+ " logits_7 = model(6, fine_tokens_input_7)\n",
799
+ " logits_8 = model(7, fine_tokens_input_8)\n",
800
+ "\n",
801
+ " # Calculate the loss\n",
802
+ " loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), fine_targets_7.view(-1))\n",
803
+ " loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n",
804
+ "\n",
805
+ " loss = (loss_7 + loss_8) / 2\n",
806
+ "\n",
807
+ " accelerator.backward(loss)\n",
808
+ " if accelerator.sync_gradients:\n",
809
+ " params_to_clip = (\n",
810
+ " param for param in model.parameters() if param.requires_grad\n",
811
+ " )\n",
812
+ " accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)\n",
813
+ " optimizer.step()\n",
814
+ " lr_scheduler.step()\n",
815
+ " optimizer.zero_grad()\n",
816
+ "\n",
817
+ " # Checks if the accelerator has performed an optimization step behind the scenes\n",
818
+ " if accelerator.sync_gradients:\n",
819
+ " progress_bar.update(1)\n",
820
+ " global_step += 1\n",
821
+ "\n",
822
+ " if global_step % checkpointing_steps == 0:\n",
823
+ " if accelerator.is_main_process:\n",
824
+ " save_path = os.path.join(output_dir, f\"checkpoint-{global_step}\")\n",
825
+ " accelerator.save_state(save_path)\n",
826
+ " logger.info(f\"Saved state to {save_path}\")\n",
827
+ "\n",
828
+ " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n",
829
+ " progress_bar.set_postfix(**logs)\n",
830
+ " accelerator.log(logs, step=global_step)\n",
831
+ "\n",
832
+ " if global_step >= max_train_steps:\n",
833
+ " break\n",
834
+ " \n",
835
+ " accelerator.wait_for_everyone()\n",
836
+ "\n",
837
+ "if accelerator.is_main_process:\n",
838
+ " if lora_dim > 0:\n",
839
+ " model = convert_lora_to_linear_layer(model)\n",
840
+ " # save model\n",
841
+ " accelerator.save(model.state_dict(), os.path.join(output_dir, \"pytorch_model.bin\"))\n",
842
+ " \n",
843
+ " config = model.config.__dict__\n",
844
+ " # save config\n",
845
+ " with open(os.path.join(output_dir, \"config.json\"), \"w\") as f:\n",
846
+ " json.dump(config, f, indent=2)\n",
847
+ "\n",
848
+ "accelerator.end_training()"
849
+ ]
850
+ },
851
+ {
852
+ "attachments": {},
853
+ "cell_type": "markdown",
854
+ "metadata": {},
855
+ "source": [
856
+ "# Validation"
857
+ ]
858
+ },
859
+ {
860
+ "cell_type": "code",
861
+ "execution_count": null,
862
+ "metadata": {},
863
+ "outputs": [],
864
+ "source": [
865
+ "if accelerator.is_main_process:\n",
866
+ " model.eval()\n",
867
+ " validation_loss = 0.0\n",
868
+ " num_batches = 0\n",
869
+ " num_samples = 0\n",
870
+ " with torch.no_grad():\n",
871
+ " for val_step, val_batch in enumerate(validation_dataloader):\n",
872
+ " # Similar to training, process the validation batch\n",
873
+ " fine_targets_7 = val_batch['fine_tokens'][:, :, 6]\n",
874
+ " fine_tokens_input_7 = torch.cat([val_batch['fine_tokens'][:, :, :6], torch.zeros_like(val_batch['fine_tokens'][:, :, 6:])], dim=2)\n",
875
+ " fine_targets_8 = val_batch['fine_tokens'][:, :, 7]\n",
876
+ " fine_tokens_input_8 = torch.cat([val_batch['fine_tokens'][:, :, :7], torch.zeros_like(val_batch['fine_tokens'][:, :, 7:])], dim=2)\n",
877
+ "\n",
878
+ " # Forward pass for validation\n",
879
+ " logits_7 = model(6, fine_tokens_input_7)\n",
880
+ " logits_8 = model(7, fine_tokens_input_8)\n",
881
+ "\n",
882
+ " # Calculate the validation loss\n",
883
+ " loss_7 = criterion(logits_7.view(-1, model.config.output_vocab_size), fine_targets_7.view(-1))\n",
884
+ " loss_8 = criterion(logits_8.view(-1, model.config.output_vocab_size), fine_targets_8.view(-1))\n",
885
+ "\n",
886
+ " loss = (loss_7 + loss_8) / 2\n",
887
+ " validation_loss += loss.item()\n",
888
+ " num_batches += 1\n",
889
+ " num_samples += val_batch['fine_tokens'].size(0)\n",
890
+ "\n",
891
+ " average_validation_loss = validation_loss / num_batches\n",
892
+ " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
893
+ " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
894
+ ]
895
+ }
896
+ ],
897
+ "metadata": {
898
+ "kernelspec": {
899
+ "display_name": "Python 3",
900
+ "language": "python",
901
+ "name": "python3"
902
+ },
903
+ "language_info": {
904
+ "codemirror_mode": {
905
+ "name": "ipython",
906
+ "version": 3
907
+ },
908
+ "file_extension": ".py",
909
+ "mimetype": "text/x-python",
910
+ "name": "python",
911
+ "nbconvert_exporter": "python",
912
+ "pygments_lexer": "ipython3",
913
+ "version": "3.10.8"
914
+ },
915
+ "orig_nbformat": 4
916
+ },
917
+ "nbformat": 4,
918
+ "nbformat_minor": 2
919
+ }
train_semantic.ipynb ADDED
@@ -0,0 +1,899 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "attachments": {},
5
+ "cell_type": "markdown",
6
+ "metadata": {},
7
+ "source": [
8
+ "# Imports"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "code",
13
+ "execution_count": null,
14
+ "metadata": {},
15
+ "outputs": [],
16
+ "source": [
17
+ "import torch\n",
18
+ "import torch.nn as nn\n",
19
+ "import os\n",
20
+ "import re\n",
21
+ "import gc\n",
22
+ "import json\n",
23
+ "import math\n",
24
+ "import hashlib\n",
25
+ "import numpy as np\n",
26
+ "import logging\n",
27
+ "import torchaudio\n",
28
+ "from tqdm.auto import tqdm\n",
29
+ "import torch.nn.functional as F\n",
30
+ "from encodec.utils import convert_audio\n",
31
+ "from accelerate import Accelerator\n",
32
+ "from accelerate.utils import set_seed\n",
33
+ "from transformers import BertTokenizer\n",
34
+ "from huggingface_hub import hf_hub_download\n",
35
+ "from packaging import version\n",
36
+ "from diffusers.optimization import get_scheduler\n",
37
+ "\n",
38
+ "from utils.bitsandbytes import BitsAndBytesConfig, importlib_metadata, get_keys_to_not_convert, replace_with_bnb_linear, set_module_quantized_tensor_to_device\n",
39
+ "from utils.lora import convert_linear_layer_to_lora, only_optimize_lora_parameters, convert_lora_to_linear_layer\n",
40
+ "from bark.model import GPTConfig, GPT\n",
41
+ "from bark.model_fine import FineGPT, FineGPTConfig"
42
+ ]
43
+ },
44
+ {
45
+ "attachments": {},
46
+ "cell_type": "markdown",
47
+ "metadata": {},
48
+ "source": [
49
+ "# Training Args"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "metadata": {},
56
+ "outputs": [],
57
+ "source": [
58
+ "train_batch_size = 8\n",
59
+ "eval_batch_size = 8\n",
60
+ "grad_accum = 2\n",
61
+ "ckpt_path = 'models/text_2.pt'\n",
62
+ "model_type = \"text\"\n",
63
+ "dataset_path = 'datasets/joe_biden_state_of_union/'\n",
64
+ "logging_dir = 'logs/'\n",
65
+ "log_with = 'wandb'\n",
66
+ "hubert_path = 'data/models/hubert/hubert.pt'\n",
67
+ "hubert_tokenizer_path = 'data/models/hubert/tokenizer.pth'\n",
68
+ "\n",
69
+ "output_dir = 'semantic_output/'\n",
70
+ "resume_from_checkpoint = None\n",
71
+ "\n",
72
+ "checkpointing_steps = 1000\n",
73
+ "\n",
74
+ "mixed_precision = 'bf16'\n",
75
+ "bits = 16 #4 4 and 8 bit are a work in progress\n",
76
+ "compute_dtype = torch.bfloat16\n",
77
+ "double_quant = True\n",
78
+ "quant_type = 'nf4'\n",
79
+ "\n",
80
+ "lora_dim = 64\n",
81
+ "lora_scaling = 1\n",
82
+ "lora_dropout = 0.1\n",
83
+ "lora_module_name = 'transformer.h'\n",
84
+ "optimize_lora_params_only = False\n",
85
+ "\n",
86
+ "learning_rate = 1e-4\n",
87
+ "scale_lr = False\n",
88
+ "use_8bit_adam = False\n",
89
+ "adam_beta1 = 0.9\n",
90
+ "adam_beta2 = 0.999\n",
91
+ "adam_epsilon = 1e-8\n",
92
+ "weight_decay = 0.01\n",
93
+ "\n",
94
+ "llm_int8_skip_modules = None\n",
95
+ "keep_in_fp32_modules = ['lm_head']\n",
96
+ "\n",
97
+ "lr_scheduler_type = 'linear'\n",
98
+ "lr_warmup_steps = 60\n",
99
+ "num_train_epochs = 5\n",
100
+ "max_train_steps = None\n",
101
+ "max_grad_norm = 1.0\n",
102
+ "\n",
103
+ "seed = 741"
104
+ ]
105
+ },
106
+ {
107
+ "attachments": {},
108
+ "cell_type": "markdown",
109
+ "metadata": {},
110
+ "source": [
111
+ "# Define Functions"
112
+ ]
113
+ },
114
+ {
115
+ "cell_type": "code",
116
+ "execution_count": null,
117
+ "metadata": {},
118
+ "outputs": [],
119
+ "source": [
120
+ "CONTEXT_WINDOW_SIZE = 1024\n",
121
+ "\n",
122
+ "MAX_TEXT_LEN = 256\n",
123
+ "\n",
124
+ "SEMANTIC_RATE_HZ = 49.9\n",
125
+ "SEMANTIC_VOCAB_SIZE = 10_000\n",
126
+ "\n",
127
+ "TEXT_ENCODING_OFFSET = 10_048\n",
128
+ "SEMANTIC_PAD_TOKEN = 10_000\n",
129
+ "TEXT_PAD_TOKEN = 129_595\n",
130
+ "SEMANTIC_INFER_TOKEN = 129_599\n",
131
+ "\n",
132
+ "MAX_SEMANTIC_LEN = 511\n",
133
+ "\n",
134
+ "SAMPLE_RATE = 24_000\n",
135
+ "CHANNELS = 1\n",
136
+ "\n",
137
+ "logger = logging.getLogger(__name__)\n",
138
+ "\n",
139
+ "\n",
140
+ "USE_SMALL_MODELS = os.environ.get(\"SERP_USE_SMALL_MODELS\", False)\n",
141
+ "\n",
142
+ "default_cache_dir = os.path.join(os.path.expanduser(\"~\"), \".cache\")\n",
143
+ "CACHE_DIR = os.path.join(os.getenv(\"XDG_CACHE_HOME\", default_cache_dir), \"serp\", \"bark_v0\")\n",
144
+ "\n",
145
+ "\n",
146
+ "def _clear_cuda_cache():\n",
147
+ " if torch.cuda.is_available():\n",
148
+ " torch.cuda.empty_cache()\n",
149
+ " torch.cuda.synchronize()\n",
150
+ "\n",
151
+ "\n",
152
+ "def _md5(fname):\n",
153
+ " hash_md5 = hashlib.md5()\n",
154
+ " with open(fname, \"rb\") as f:\n",
155
+ " for chunk in iter(lambda: f.read(4096), b\"\"):\n",
156
+ " hash_md5.update(chunk)\n",
157
+ " return hash_md5.hexdigest()\n",
158
+ "\n",
159
+ "\n",
160
+ "def _download(from_hf_path, file_name, to_local_path):\n",
161
+ " to_local_path = to_local_path.replace(\"\\\\\", \"/\")\n",
162
+ " path = '/'.join(to_local_path.split(\"/\")[:-1])\n",
163
+ " os.makedirs(path, exist_ok=True)\n",
164
+ " hf_hub_download(repo_id=from_hf_path, filename=file_name, local_dir=path)\n",
165
+ " os.replace(os.path.join(path, file_name), to_local_path)\n",
166
+ "\n",
167
+ "\n",
168
+ "def _tokenize(tokenizer, text):\n",
169
+ " return tokenizer.encode(text, add_special_tokens=False)\n",
170
+ "\n",
171
+ "\n",
172
+ "def _detokenize(tokenizer, enc_text):\n",
173
+ " return tokenizer.decode(enc_text)\n",
174
+ "\n",
175
+ "\n",
176
+ "def _normalize_whitespace(text):\n",
177
+ " return re.sub(r\"\\s+\", \" \", text).strip()\n",
178
+ "\n",
179
+ "\n",
180
+ "REMOTE_MODEL_PATHS = {\n",
181
+ " \"text_small\": {\n",
182
+ " \"repo_id\": \"suno/bark\",\n",
183
+ " \"file_name\": \"text.pt\",\n",
184
+ " \"checksum\": \"b3e42bcbab23b688355cd44128c4cdd3\",\n",
185
+ " },\n",
186
+ " \"coarse_small\": {\n",
187
+ " \"repo_id\": \"suno/bark\",\n",
188
+ " \"file_name\": \"coarse.pt\",\n",
189
+ " \"checksum\": \"5fe964825e3b0321f9d5f3857b89194d\",\n",
190
+ " },\n",
191
+ " \"fine_small\": {\n",
192
+ " \"repo_id\": \"suno/bark\",\n",
193
+ " \"file_name\": \"fine.pt\",\n",
194
+ " \"checksum\": \"5428d1befe05be2ba32195496e58dc90\",\n",
195
+ " },\n",
196
+ " \"text\": {\n",
197
+ " \"repo_id\": \"suno/bark\",\n",
198
+ " \"file_name\": \"text_2.pt\",\n",
199
+ " \"checksum\": \"54afa89d65e318d4f5f80e8e8799026a\",\n",
200
+ " },\n",
201
+ " \"coarse\": {\n",
202
+ " \"repo_id\": \"suno/bark\",\n",
203
+ " \"file_name\": \"coarse_2.pt\",\n",
204
+ " \"checksum\": \"8a98094e5e3a255a5c9c0ab7efe8fd28\",\n",
205
+ " },\n",
206
+ " \"fine\": {\n",
207
+ " \"repo_id\": \"suno/bark\",\n",
208
+ " \"file_name\": \"fine_2.pt\",\n",
209
+ " \"checksum\": \"59d184ed44e3650774a2f0503a48a97b\",\n",
210
+ " },\n",
211
+ "}\n",
212
+ "\n",
213
+ "\n",
214
+ "def _load_model(ckpt_path, device, use_small=False, model_type=\"text\"):\n",
215
+ " if model_type == \"text\":\n",
216
+ " ConfigClass = GPTConfig\n",
217
+ " ModelClass = GPT\n",
218
+ " elif model_type == \"coarse\":\n",
219
+ " ConfigClass = GPTConfig\n",
220
+ " ModelClass = GPT\n",
221
+ " elif model_type == \"fine\":\n",
222
+ " ConfigClass = FineGPTConfig\n",
223
+ " ModelClass = FineGPT\n",
224
+ " else:\n",
225
+ " raise NotImplementedError()\n",
226
+ " model_key = f\"{model_type}_small\" if use_small or USE_SMALL_MODELS else model_type\n",
227
+ " model_info = REMOTE_MODEL_PATHS[model_key]\n",
228
+ " if ckpt_path in [None, '']:\n",
229
+ " ckpt_path = os.path.join(CACHE_DIR, model_info[\"file_name\"])\n",
230
+ " if not os.path.exists(ckpt_path):\n",
231
+ " logger.info(f\"{model_type} model not found, downloading into `{CACHE_DIR}`.\")\n",
232
+ " _download(model_info[\"repo_id\"], model_info[\"file_name\"], ckpt_path)\n",
233
+ " checkpoint = torch.load(ckpt_path, map_location=device)\n",
234
+ " # this is a hack\n",
235
+ " model_args = checkpoint[\"model_args\"]\n",
236
+ " if \"input_vocab_size\" not in model_args:\n",
237
+ " model_args[\"input_vocab_size\"] = model_args[\"vocab_size\"]\n",
238
+ " model_args[\"output_vocab_size\"] = model_args[\"vocab_size\"]\n",
239
+ " del model_args[\"vocab_size\"]\n",
240
+ " gptconf = ConfigClass(**checkpoint[\"model_args\"])\n",
241
+ " model = ModelClass(gptconf)\n",
242
+ " state_dict = checkpoint[\"model\"]\n",
243
+ " # fixup checkpoint\n",
244
+ " unwanted_prefix = \"_orig_mod.\"\n",
245
+ " for k, v in list(state_dict.items()):\n",
246
+ " if k.startswith(unwanted_prefix):\n",
247
+ " state_dict[k[len(unwanted_prefix) :]] = state_dict.pop(k)\n",
248
+ " extra_keys = set(state_dict.keys()) - set(model.state_dict().keys())\n",
249
+ " extra_keys = set([k for k in extra_keys if not k.endswith(\".attn.bias\")])\n",
250
+ " missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())\n",
251
+ " missing_keys = set([k for k in missing_keys if not k.endswith(\".attn.bias\")])\n",
252
+ " if len(extra_keys) != 0:\n",
253
+ " raise ValueError(f\"extra keys found: {extra_keys}\")\n",
254
+ " if len(missing_keys) != 0:\n",
255
+ " raise ValueError(f\"missing keys: {missing_keys}\")\n",
256
+ " model.load_state_dict(state_dict, strict=False)\n",
257
+ " n_params = model.get_num_params()\n",
258
+ " val_loss = checkpoint[\"best_val_loss\"].item()\n",
259
+ " print(f\"Loaded {model_type} model with {n_params} params, val_loss={val_loss:.4f}.\")\n",
260
+ " del checkpoint, state_dict\n",
261
+ " _clear_cuda_cache()\n",
262
+ " if model_type == \"text\":\n",
263
+ " tokenizer = BertTokenizer.from_pretrained(\"bert-base-multilingual-cased\")\n",
264
+ " return model, tokenizer\n",
265
+ " return model\n",
266
+ "\n",
267
+ "\n",
268
+ "def load_filepaths_and_text(filename, split=\"|\"):\n",
269
+ " with open(filename, encoding='utf-8', errors='ignore') as f:\n",
270
+ " filepaths_and_text = [line.strip().split(split) for line in f]\n",
271
+ " base = os.path.dirname(filename)\n",
272
+ " for j in range(len(filepaths_and_text)):\n",
273
+ " filepaths_and_text[j][0] = os.path.join(base, filepaths_and_text[j][0])\n",
274
+ " return filepaths_and_text\n",
275
+ "\n",
276
+ "class TtsDataset(torch.utils.data.Dataset):\n",
277
+ " def __init__(self, opt):\n",
278
+ " self.path = os.path.dirname(opt['path'])\n",
279
+ " self.mode = opt['mode']\n",
280
+ " self.audiopaths_and_text = load_filepaths_and_text(os.path.join(opt['path'] , opt['mode'] + '_valid.txt'))\n",
281
+ " self.tokenizer = opt['tokenizer']\n",
282
+ "\n",
283
+ " def __getitem__(self, index):\n",
284
+ " audiopath_and_text = self.audiopaths_and_text[index]\n",
285
+ " audiopath, text = audiopath_and_text[0], audiopath_and_text[1]\n",
286
+ "\n",
287
+ " input_ids = np.array(_tokenize(self.tokenizer, text)) + TEXT_ENCODING_OFFSET\n",
288
+ " input_ids = torch.from_numpy(input_ids).long()\n",
289
+ " tokens = np.load(audiopath.replace('.wav', '.npz').replace('wavs', 'tokens'))\n",
290
+ " semantic_tokens = tokens['semantic']\n",
291
+ " semantic_tokens = torch.from_numpy(semantic_tokens).long()\n",
292
+ "\n",
293
+ " return input_ids, semantic_tokens\n",
294
+ "\n",
295
+ " def __len__(self):\n",
296
+ " return len(self.audiopaths_and_text)\n",
297
+ "\n",
298
+ "\n",
299
+ "class TtsCollater():\n",
300
+ " def __init__(self):\n",
301
+ " pass\n",
302
+ " def __call__(self, batch):\n",
303
+ " max_text_len = MAX_TEXT_LEN\n",
304
+ " max_semantic_tokens_len = MAX_SEMANTIC_LEN\n",
305
+ " texts = []\n",
306
+ " semantic_tokens = []\n",
307
+ "\n",
308
+ " for b in batch:\n",
309
+ " text, semantic_tokens_ = b\n",
310
+ " text = F.pad(text, (0, max_text_len-len(text)), value=TEXT_PAD_TOKEN)\n",
311
+ " semantic_history = torch.from_numpy(np.array([SEMANTIC_PAD_TOKEN] * 256))\n",
312
+ " text = torch.cat([text, semantic_history, torch.tensor([SEMANTIC_INFER_TOKEN])])\n",
313
+ " texts.append(text)\n",
314
+ " semantic_tokens_ = semantic_tokens_[:max_semantic_tokens_len]\n",
315
+ " semantic_tokens.append(F.pad(semantic_tokens_, (0, max_semantic_tokens_len-len(semantic_tokens_)), value=SEMANTIC_PAD_TOKEN))\n",
316
+ "\n",
317
+ " return {\n",
318
+ " 'input_ids': torch.stack(texts).contiguous(),\n",
319
+ " 'semantic_tokens': torch.stack(semantic_tokens).contiguous()\n",
320
+ " }\n",
321
+ " \n",
322
+ "\n",
323
+ "accelerator = Accelerator(\n",
324
+ " gradient_accumulation_steps=grad_accum,\n",
325
+ " mixed_precision=mixed_precision,\n",
326
+ " log_with=log_with,\n",
327
+ " logging_dir=logging_dir,\n",
328
+ ")\n",
329
+ "device = accelerator.device\n",
330
+ "\n",
331
+ "os.makedirs(output_dir, exist_ok=True)\n",
332
+ "\n",
333
+ "set_seed(seed)"
334
+ ]
335
+ },
336
+ {
337
+ "attachments": {},
338
+ "cell_type": "markdown",
339
+ "metadata": {},
340
+ "source": [
341
+ "# Setup Dataset (only need to do this once)"
342
+ ]
343
+ },
344
+ {
345
+ "cell_type": "code",
346
+ "execution_count": null,
347
+ "metadata": {},
348
+ "outputs": [],
349
+ "source": [
350
+ "max_duration_sec = 15.12 # the maximum allowed duration in seconds\n",
351
+ "\n",
352
+ "path = dataset_path\n",
353
+ "\n",
354
+ "# From https://github.com/gitmylo/bark-voice-cloning-HuBERT-quantizer\n",
355
+ "from hubert.hubert_manager import HuBERTManager\n",
356
+ "hubert_manager = HuBERTManager()\n",
357
+ "from hubert.pre_kmeans_hubert import CustomHubert\n",
358
+ "from hubert.customtokenizer import CustomTokenizer\n",
359
+ "hubert_manager.make_sure_hubert_installed()\n",
360
+ "hubert_manager.make_sure_tokenizer_installed()\n",
361
+ "\n",
362
+ "# Load the HuBERT model\n",
363
+ "hubert_model = CustomHubert(checkpoint_path=hubert_path).to(device)\n",
364
+ "hubert_model.eval()\n",
365
+ "for param in hubert_model.parameters():\n",
366
+ " param.requires_grad = False\n",
367
+ "\n",
368
+ "# Load the CustomTokenizer model\n",
369
+ "hubert_tokenizer = CustomTokenizer.load_from_checkpoint(hubert_tokenizer_path).to(device) # Automatically uses the right layers\n",
370
+ "\n",
371
+ "from bark.generation import load_codec_model\n",
372
+ "codec_model = load_codec_model(use_gpu=True)\n",
373
+ "codec_model.eval()\n",
374
+ "for param in codec_model.parameters():\n",
375
+ " param.requires_grad = False\n",
376
+ "\n",
377
+ "\n",
378
+ "def get_duration(wav, sr):\n",
379
+ " return wav.shape[1] / sr\n",
380
+ "\n",
381
+ "valid_lines_train = []\n",
382
+ "# convert wavs to semantic tokens\n",
383
+ "for wav_path, txt in load_filepaths_and_text(path + 'train.txt'):\n",
384
+ " wav, sr = torchaudio.load(wav_path)\n",
385
+ " if not get_duration(wav, sr) > max_duration_sec:\n",
386
+ " valid_lines_train.append((wav_path, txt))\n",
387
+ " wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
388
+ "\n",
389
+ " semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
390
+ " semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
391
+ "\n",
392
+ " # save semantic tokens\n",
393
+ " os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
394
+ " semantic_tokens = semantic_tokens.cpu().numpy()\n",
395
+ "\n",
396
+ " # Extract discrete codes from EnCodec\n",
397
+ " with torch.no_grad():\n",
398
+ " encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
399
+ " codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n",
400
+ "\n",
401
+ " # move codes to cpu\n",
402
+ " codes = codes.cpu().numpy()\n",
403
+ "\n",
404
+ " # save tokens\n",
405
+ " np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
406
+ "\n",
407
+ "# rewrite train.txt with valid lines\n",
408
+ "with open(path + 'train_valid.txt', 'w', encoding='utf-8') as f:\n",
409
+ " for wav_path, txt in valid_lines_train:\n",
410
+ " wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
411
+ " f.write(f'{wav_path}|{txt}\\n')\n",
412
+ "\n",
413
+ "valid_lines_valid = []\n",
414
+ "for wav_path, txt in load_filepaths_and_text(path + 'valid.txt'):\n",
415
+ " wav, sr = torchaudio.load(wav_path)\n",
416
+ " if not get_duration(wav, sr) > max_duration_sec:\n",
417
+ " valid_lines_valid.append((wav_path, txt))\n",
418
+ " wav = convert_audio(wav, sr, SAMPLE_RATE, CHANNELS).to(device)\n",
419
+ "\n",
420
+ " semantic_vectors = hubert_model.forward(wav, input_sample_hz=SAMPLE_RATE)\n",
421
+ " semantic_tokens = hubert_tokenizer.get_token(semantic_vectors)\n",
422
+ "\n",
423
+ " # save semantic tokens\n",
424
+ " os.makedirs(os.path.join(path, 'tokens'), exist_ok=True)\n",
425
+ " semantic_tokens = semantic_tokens.cpu().numpy()\n",
426
+ " \n",
427
+ " # Extract discrete codes from EnCodec\n",
428
+ " with torch.no_grad():\n",
429
+ " encoded_frames = codec_model.encode(wav.unsqueeze(0))\n",
430
+ " codes = torch.cat([encoded[0] for encoded in encoded_frames], dim=-1).squeeze() # [n_q, T]\n",
431
+ "\n",
432
+ " # move codes to cpu\n",
433
+ " codes = codes.cpu().numpy()\n",
434
+ "\n",
435
+ " # save tokens\n",
436
+ " np.savez_compressed(os.path.join(path, 'tokens', os.path.basename(wav_path).replace('.wav', '.npz')), fine=codes, coarse=codes[:2, :], semantic=semantic_tokens)\n",
437
+ "\n",
438
+ "# rewrite valid.txt with valid lines\n",
439
+ "with open(path + 'valid_valid.txt', 'w', encoding='utf-8') as f:\n",
440
+ " for wav_path, txt in valid_lines_valid:\n",
441
+ " wav_path = os.path.relpath(wav_path, dataset_path).replace('\\\\', '/')\n",
442
+ " f.write(f'{wav_path}|{txt}\\n')\n",
443
+ "\n",
444
+ "del hubert_model\n",
445
+ "del hubert_tokenizer\n",
446
+ "del codec_model\n",
447
+ "gc.collect()\n",
448
+ "torch.cuda.empty_cache()"
449
+ ]
450
+ },
451
+ {
452
+ "attachments": {},
453
+ "cell_type": "markdown",
454
+ "metadata": {},
455
+ "source": [
456
+ "# Setup"
457
+ ]
458
+ },
459
+ {
460
+ "cell_type": "code",
461
+ "execution_count": null,
462
+ "metadata": {},
463
+ "outputs": [],
464
+ "source": [
465
+ "model, tokenizer = _load_model(ckpt_path, device, use_small=False, model_type=model_type)"
466
+ ]
467
+ },
468
+ {
469
+ "cell_type": "code",
470
+ "execution_count": null,
471
+ "metadata": {},
472
+ "outputs": [],
473
+ "source": [
474
+ "if scale_lr:\n",
475
+ " learning_rate = (\n",
476
+ " learning_rate * grad_accum * train_batch_size * accelerator.num_processes\n",
477
+ " )\n",
478
+ "\n",
479
+ "if use_8bit_adam:\n",
480
+ " try:\n",
481
+ " import bitsandbytes as bnb\n",
482
+ " except ImportError:\n",
483
+ " raise ImportError(\n",
484
+ " \"To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`.\"\n",
485
+ " )\n",
486
+ "\n",
487
+ " optimizer_class = bnb.optim.AdamW8bit\n",
488
+ "else:\n",
489
+ " optimizer_class = torch.optim.AdamW"
490
+ ]
491
+ },
492
+ {
493
+ "cell_type": "code",
494
+ "execution_count": null,
495
+ "metadata": {},
496
+ "outputs": [],
497
+ "source": [
498
+ "quantization_config=BitsAndBytesConfig(\n",
499
+ " load_in_4bit=bits == 4,\n",
500
+ " load_in_8bit=bits == 8,\n",
501
+ " llm_int8_threshold=6.0,\n",
502
+ " llm_int8_has_fp16_weight=False,\n",
503
+ " bnb_4bit_compute_dtype=compute_dtype,\n",
504
+ " bnb_4bit_use_double_quant=double_quant,\n",
505
+ " bnb_4bit_quant_type=quant_type # {'fp4', 'nf4'}\n",
506
+ ")\n",
507
+ "\n",
508
+ "# if quantization_config.load_in_8bit or quantization_config.load_in_4bit:\n",
509
+ "# if quantization_config.load_in_8bit:\n",
510
+ "# logger.info(\"Detected 8-bit loading: activating 8-bit loading for this model\")\n",
511
+ "# elif quantization_config.load_in_4bit:\n",
512
+ "# logger.info(\"Detected 4-bit loading: activating 4-bit loading for this model\")\n",
513
+ "\n",
514
+ "# # We keep some modules such as the lm_head in their original dtype for numerical stability reasons\n",
515
+ "# if llm_int8_skip_modules is None or len(llm_int8_skip_modules) == 0:\n",
516
+ "# modules_to_not_convert = [] # get_keys_to_not_convert(model)\n",
517
+ "# else:\n",
518
+ "# modules_to_not_convert = llm_int8_skip_modules\n",
519
+ "\n",
520
+ "# if not isinstance(modules_to_not_convert, list):\n",
521
+ "# modules_to_not_convert = [modules_to_not_convert]\n",
522
+ "\n",
523
+ "# modules_to_not_convert.extend(keep_in_fp32_modules)\n",
524
+ "\n",
525
+ "# supports_4bit = version.parse(importlib_metadata.version(\"bitsandbytes\")) >= version.parse(\"0.39.0\")\n",
526
+ "\n",
527
+ "# if quantization_config.load_in_4bit and not supports_4bit:\n",
528
+ "# raise ValueError(\n",
529
+ "# \"You have a version of `bitsandbytes` that is not compatible with 4bit inference and training\"\n",
530
+ "# \" make sure you have the latest version of `bitsandbytes` installed\"\n",
531
+ "# )\n",
532
+ " \n",
533
+ "# if len(modules_to_not_convert) == 0:\n",
534
+ "# modules_to_not_convert = None\n",
535
+ "\n",
536
+ "# model = replace_with_bnb_linear(\n",
537
+ "# model, modules_to_not_convert=modules_to_not_convert, quantization_config=quantization_config\n",
538
+ "# )\n",
539
+ "\n",
540
+ "# # training in 8-bit is only available in 0.37.0+\n",
541
+ "# model._is_kbit_training_enabled = version.parse(\n",
542
+ "# importlib_metadata.version(\"bitsandbytes\")\n",
543
+ "# ) >= version.parse(\"0.37.0\")\n",
544
+ "\n",
545
+ "# model.config.quantization_config = quantization_config"
546
+ ]
547
+ },
548
+ {
549
+ "cell_type": "code",
550
+ "execution_count": null,
551
+ "metadata": {},
552
+ "outputs": [],
553
+ "source": [
554
+ "if bits == 4:\n",
555
+ " from accelerate.utils import CustomDtype\n",
556
+ " target_dtype = CustomDtype.INT4\n",
557
+ "elif bits == 8:\n",
558
+ " target_dtype = torch.int8\n",
559
+ "\n",
560
+ "if lora_dim > 0:\n",
561
+ " for param in model.parameters():\n",
562
+ " if param.ndim == 1:\n",
563
+ " # cast the small parameters (e.g. layernorm) to fp32 for stability\n",
564
+ " param.data = param.data.to(torch.float32)\n",
565
+ " \n",
566
+ " class CastOutputToFloat(nn.Sequential):\n",
567
+ " def forward(self, x):\n",
568
+ " return super().forward(x).to(torch.float32)\n",
569
+ "\n",
570
+ " model.lm_head = CastOutputToFloat(model.lm_head)\n",
571
+ "\n",
572
+ " model = convert_linear_layer_to_lora(model, lora_module_name,\n",
573
+ " lora_dim=lora_dim, lora_scaling=lora_scaling,\n",
574
+ " lora_dropout=lora_dropout)\n",
575
+ " if optimize_lora_params_only:\n",
576
+ " model = only_optimize_lora_parameters(model)"
577
+ ]
578
+ },
579
+ {
580
+ "cell_type": "code",
581
+ "execution_count": null,
582
+ "metadata": {},
583
+ "outputs": [],
584
+ "source": [
585
+ "params_to_optimize = (\n",
586
+ " param for param in model.parameters() if param.requires_grad\n",
587
+ " )\n",
588
+ "\n",
589
+ "optimizer = optimizer_class(\n",
590
+ " params_to_optimize,\n",
591
+ " lr=learning_rate,\n",
592
+ " betas=(adam_beta1, adam_beta2),\n",
593
+ " weight_decay=weight_decay,\n",
594
+ " eps=adam_epsilon,\n",
595
+ ")"
596
+ ]
597
+ },
598
+ {
599
+ "cell_type": "code",
600
+ "execution_count": null,
601
+ "metadata": {},
602
+ "outputs": [],
603
+ "source": [
604
+ "opt_train = {\n",
605
+ " 'path': dataset_path,\n",
606
+ " 'tokenizer': tokenizer,\n",
607
+ " 'mode': 'train',\n",
608
+ "}\n",
609
+ "\n",
610
+ "opt_val = {\n",
611
+ " 'path': dataset_path,\n",
612
+ " 'tokenizer': tokenizer,\n",
613
+ " 'mode': 'valid',\n",
614
+ "}\n",
615
+ "\n",
616
+ "train_dataset = TtsDataset(opt_train)\n",
617
+ "validation_dataset = TtsDataset(opt_val)\n",
618
+ "\n",
619
+ "train_dataloader = torch.utils.data.DataLoader(\n",
620
+ " train_dataset,\n",
621
+ " batch_size=train_batch_size,\n",
622
+ " collate_fn=TtsCollater(),\n",
623
+ ")\n",
624
+ "\n",
625
+ "validation_dataloader = torch.utils.data.DataLoader(\n",
626
+ " validation_dataset,\n",
627
+ " batch_size=eval_batch_size,\n",
628
+ " collate_fn=TtsCollater(),\n",
629
+ ")\n",
630
+ "\n",
631
+ "criterion = torch.nn.CrossEntropyLoss() #ignore_index=SEMANTIC_PAD_TOKEN)\n",
632
+ "\n",
633
+ "# Scheduler and math around the number of training steps.\n",
634
+ "overrode_max_train_steps = False\n",
635
+ "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
636
+ "if max_train_steps is None:\n",
637
+ " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
638
+ " overrode_max_train_steps = True\n",
639
+ "\n",
640
+ "lr_scheduler = get_scheduler(\n",
641
+ " lr_scheduler_type,\n",
642
+ " optimizer=optimizer,\n",
643
+ " num_warmup_steps=lr_warmup_steps * grad_accum,\n",
644
+ " num_training_steps=max_train_steps * grad_accum,\n",
645
+ ")"
646
+ ]
647
+ },
648
+ {
649
+ "cell_type": "code",
650
+ "execution_count": null,
651
+ "metadata": {},
652
+ "outputs": [],
653
+ "source": [
654
+ "model, optimizer, train_dataloader, validation_dataloader, lr_scheduler = accelerator.prepare(\n",
655
+ " model, optimizer, train_dataloader, validation_dataloader, lr_scheduler\n",
656
+ ")\n",
657
+ "accelerator.register_for_checkpointing(lr_scheduler)\n",
658
+ "\n",
659
+ "weight_dtype = torch.float32\n",
660
+ "if accelerator.mixed_precision == \"fp16\":\n",
661
+ " weight_dtype = torch.float16\n",
662
+ "elif accelerator.mixed_precision == \"bf16\":\n",
663
+ " weight_dtype = torch.bfloat16"
664
+ ]
665
+ },
666
+ {
667
+ "cell_type": "code",
668
+ "execution_count": null,
669
+ "metadata": {},
670
+ "outputs": [],
671
+ "source": [
672
+ "# We need to recalculate our total training steps as the size of the training dataloader may have changed.\n",
673
+ "num_update_steps_per_epoch = math.ceil(len(train_dataloader) / grad_accum)\n",
674
+ "if overrode_max_train_steps:\n",
675
+ " max_train_steps = num_train_epochs * num_update_steps_per_epoch\n",
676
+ "# Afterwards we recalculate our number of training epochs\n",
677
+ "num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)\n",
678
+ "\n",
679
+ "# We need to initialize the trackers we use, and also store our configuration.\n",
680
+ "# The trackers initializes automatically on the main process.\n",
681
+ "if accelerator.is_main_process:\n",
682
+ " accelerator.init_trackers(\"bark_semantic\", config={})\n",
683
+ "\n",
684
+ "# Train!\n",
685
+ "total_batch_size = train_batch_size * accelerator.num_processes * grad_accum\n",
686
+ "logger.info(\"***** Running training *****\")\n",
687
+ "logger.info(f\" Num examples = {len(train_dataset)}\")\n",
688
+ "logger.info(f\" Num batches each epoch = {len(train_dataloader)}\")\n",
689
+ "logger.info(f\" Num Epochs = {num_train_epochs}\")\n",
690
+ "logger.info(f\" Instantaneous batch size per device = {train_batch_size}\")\n",
691
+ "logger.info(f\" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}\")\n",
692
+ "logger.info(f\" Gradient Accumulation steps = {grad_accum}\")\n",
693
+ "logger.info(f\" Total optimization steps = {max_train_steps}\")\n",
694
+ "global_step = 0\n",
695
+ "first_epoch = 0\n",
696
+ "\n",
697
+ "if resume_from_checkpoint:\n",
698
+ " if resume_from_checkpoint != \"latest\":\n",
699
+ " path = os.path.basename(resume_from_checkpoint)\n",
700
+ " else:\n",
701
+ " # Get the most recent checkpoint\n",
702
+ " dirs = os.listdir(output_dir)\n",
703
+ " dirs = [d for d in dirs if d.startswith(\"checkpoint\")]\n",
704
+ " dirs = sorted(dirs, key=lambda x: int(x.split(\"-\")[1]))\n",
705
+ " path = dirs[-1]\n",
706
+ " accelerator.print(f\"Resuming from checkpoint {path}\")\n",
707
+ " accelerator.load_state(os.path.join(output_dir, path))\n",
708
+ " global_step = int(path.split(\"-\")[1])\n",
709
+ "\n",
710
+ " resume_global_step = global_step * grad_accum\n",
711
+ " first_epoch = resume_global_step // num_update_steps_per_epoch\n",
712
+ " resume_step = resume_global_step % num_update_steps_per_epoch\n"
713
+ ]
714
+ },
715
+ {
716
+ "cell_type": "code",
717
+ "execution_count": null,
718
+ "metadata": {},
719
+ "outputs": [],
720
+ "source": [
721
+ "if accelerator.is_main_process:\n",
722
+ " model.eval()\n",
723
+ " validation_loss = 0.0\n",
724
+ " num_batches = 0\n",
725
+ " num_samples = 0\n",
726
+ " with torch.no_grad():\n",
727
+ " for val_step, val_batch in enumerate(validation_dataloader):\n",
728
+ " # Similar to training, process the validation batch\n",
729
+ " val_targets = val_batch['semantic_tokens'][:, 1:].contiguous()\n",
730
+ " val_semantic_inputs = val_batch['semantic_tokens'][:, :-1]\n",
731
+ " val_inputs = torch.cat([val_batch['input_ids'], val_semantic_inputs], dim=1)\n",
732
+ "\n",
733
+ " # Forward pass for validation\n",
734
+ " val_logits = model(val_inputs, training=True)\n",
735
+ " val_semantic_logits = val_logits[:, val_batch['input_ids'].size(1):].contiguous()\n",
736
+ "\n",
737
+ " # Calculate the validation loss\n",
738
+ " val_loss = criterion(val_semantic_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n",
739
+ " validation_loss += val_loss.item()\n",
740
+ " num_batches += 1\n",
741
+ " num_samples += val_batch['input_ids'].size(0)\n",
742
+ "\n",
743
+ " average_validation_loss = validation_loss / num_batches\n",
744
+ " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
745
+ " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
746
+ ]
747
+ },
748
+ {
749
+ "attachments": {},
750
+ "cell_type": "markdown",
751
+ "metadata": {},
752
+ "source": [
753
+ "# Training"
754
+ ]
755
+ },
756
+ {
757
+ "cell_type": "code",
758
+ "execution_count": null,
759
+ "metadata": {},
760
+ "outputs": [],
761
+ "source": [
762
+ "# Only show the progress bar once on each machine.\n",
763
+ "progress_bar = tqdm(range(global_step, max_train_steps), disable=not accelerator.is_local_main_process)\n",
764
+ "progress_bar.set_description(\"Steps\")\n",
765
+ "\n",
766
+ "for epoch in range(first_epoch, num_train_epochs):\n",
767
+ " model.train()\n",
768
+ " for step, batch in enumerate(train_dataloader):\n",
769
+ " # Skip steps until we reach the resumed step\n",
770
+ " if resume_from_checkpoint and epoch == first_epoch and step < resume_step:\n",
771
+ " if step % grad_accum == 0:\n",
772
+ " progress_bar.update(1)\n",
773
+ " continue\n",
774
+ "\n",
775
+ " with accelerator.accumulate(model):\n",
776
+ " targets = batch['semantic_tokens'][:, 1:].contiguous()\n",
777
+ " \n",
778
+ " # Remove the last semantic token from the inputs since there is no target for it.\n",
779
+ " semantic_inputs = batch['semantic_tokens'][:, :-1]\n",
780
+ "\n",
781
+ " # Combine the text and semantic tokens and feed them into the model.\n",
782
+ " inputs = torch.cat([batch['input_ids'], semantic_inputs], dim=1)\n",
783
+ " logits = model(inputs, training=True)\n",
784
+ "\n",
785
+ " # We're only interested in the logits for the semantic tokens, so we ignore the logits for the input text tokens.\n",
786
+ " semantic_logits = logits[:, batch['input_ids'].size(1):].contiguous()\n",
787
+ "\n",
788
+ " # Compute the loss.\n",
789
+ " loss = criterion(semantic_logits.view(-1, model.config.output_vocab_size), targets.view(-1))\n",
790
+ "\n",
791
+ " accelerator.backward(loss)\n",
792
+ " if accelerator.sync_gradients:\n",
793
+ " params_to_clip = (\n",
794
+ " param for param in model.parameters() if param.requires_grad\n",
795
+ " )\n",
796
+ " accelerator.clip_grad_norm_(params_to_clip, max_grad_norm)\n",
797
+ " optimizer.step()\n",
798
+ " lr_scheduler.step()\n",
799
+ " optimizer.zero_grad()\n",
800
+ "\n",
801
+ " # Checks if the accelerator has performed an optimization step behind the scenes\n",
802
+ " if accelerator.sync_gradients:\n",
803
+ " progress_bar.update(1)\n",
804
+ " global_step += 1\n",
805
+ "\n",
806
+ " if global_step % checkpointing_steps == 0:\n",
807
+ " if accelerator.is_main_process:\n",
808
+ " save_path = os.path.join(output_dir, f\"checkpoint-{global_step}\")\n",
809
+ " accelerator.save_state(save_path)\n",
810
+ " logger.info(f\"Saved state to {save_path}\")\n",
811
+ "\n",
812
+ " logs = {\"loss\": loss.detach().item(), \"lr\": lr_scheduler.get_last_lr()[0]}\n",
813
+ " progress_bar.set_postfix(**logs)\n",
814
+ " accelerator.log(logs, step=global_step)\n",
815
+ "\n",
816
+ " if global_step >= max_train_steps:\n",
817
+ " break\n",
818
+ " \n",
819
+ " accelerator.wait_for_everyone()\n",
820
+ "\n",
821
+ "if accelerator.is_main_process:\n",
822
+ " if lora_dim > 0:\n",
823
+ " model = convert_lora_to_linear_layer(model)\n",
824
+ " # save model\n",
825
+ " accelerator.save(model.state_dict(), os.path.join(output_dir, \"pytorch_model.bin\"))\n",
826
+ "\n",
827
+ " config = model.config.__dict__\n",
828
+ " # save config\n",
829
+ " with open(os.path.join(output_dir, \"config.json\"), \"w\") as f:\n",
830
+ " json.dump(config, f, indent=2)\n",
831
+ "\n",
832
+ "accelerator.end_training()"
833
+ ]
834
+ },
835
+ {
836
+ "attachments": {},
837
+ "cell_type": "markdown",
838
+ "metadata": {},
839
+ "source": [
840
+ "# Validation"
841
+ ]
842
+ },
843
+ {
844
+ "cell_type": "code",
845
+ "execution_count": null,
846
+ "metadata": {},
847
+ "outputs": [],
848
+ "source": [
849
+ "if accelerator.is_main_process:\n",
850
+ " model.eval()\n",
851
+ " validation_loss = 0.0\n",
852
+ " num_batches = 0\n",
853
+ " num_samples = 0\n",
854
+ " with torch.no_grad():\n",
855
+ " for val_step, val_batch in enumerate(validation_dataloader):\n",
856
+ " # Similar to training, process the validation batch\n",
857
+ " val_targets = val_batch['semantic_tokens'][:, 1:].contiguous()\n",
858
+ " val_semantic_inputs = val_batch['semantic_tokens'][:, :-1]\n",
859
+ " val_inputs = torch.cat([val_batch['input_ids'], val_semantic_inputs], dim=1)\n",
860
+ "\n",
861
+ " # Forward pass for validation\n",
862
+ " val_logits = model(val_inputs, training=True)\n",
863
+ " val_semantic_logits = val_logits[:, val_batch['input_ids'].size(1):].contiguous()\n",
864
+ "\n",
865
+ " # Calculate the validation loss\n",
866
+ " val_loss = criterion(val_semantic_logits.view(-1, model.config.output_vocab_size), val_targets.view(-1))\n",
867
+ " validation_loss += val_loss.item()\n",
868
+ " num_batches += 1\n",
869
+ " num_samples += val_batch['input_ids'].size(0)\n",
870
+ "\n",
871
+ " average_validation_loss = validation_loss / num_batches\n",
872
+ " logger.info(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")\n",
873
+ " print(f\"Validation Loss: {average_validation_loss} over {num_samples} samples and {num_batches} batches.\")"
874
+ ]
875
+ }
876
+ ],
877
+ "metadata": {
878
+ "kernelspec": {
879
+ "display_name": "Python 3",
880
+ "language": "python",
881
+ "name": "python3"
882
+ },
883
+ "language_info": {
884
+ "codemirror_mode": {
885
+ "name": "ipython",
886
+ "version": 3
887
+ },
888
+ "file_extension": ".py",
889
+ "mimetype": "text/x-python",
890
+ "name": "python",
891
+ "nbconvert_exporter": "python",
892
+ "pygments_lexer": "ipython3",
893
+ "version": "3.10.8"
894
+ },
895
+ "orig_nbformat": 4
896
+ },
897
+ "nbformat": 4,
898
+ "nbformat_minor": 2
899
+ }