AnhP committed
Commit 76a7c22 · verified · 1 Parent(s): b5a1f0d

Delete main

Files changed (41)
  1. main/app/app.py +0 -0
  2. main/app/clean.py +0 -40
  3. main/app/run_tensorboard.py +0 -30
  4. main/app/sync.py +0 -80
  5. main/configs/config.json +0 -229
  6. main/configs/config.py +0 -120
  7. main/configs/v1/32000.json +0 -82
  8. main/configs/v1/40000.json +0 -80
  9. main/configs/v1/48000.json +0 -82
  10. main/configs/v2/32000.json +0 -76
  11. main/configs/v2/40000.json +0 -76
  12. main/configs/v2/48000.json +0 -76
  13. main/inference/audio_effects.py +0 -170
  14. main/inference/convert.py +0 -1060
  15. main/inference/create_dataset.py +0 -370
  16. main/inference/create_index.py +0 -120
  17. main/inference/extract.py +0 -450
  18. main/inference/preprocess.py +0 -360
  19. main/inference/separator_music.py +0 -400
  20. main/inference/train.py +0 -1600
  21. main/library/algorithm/commons.py +0 -100
  22. main/library/algorithm/modules.py +0 -80
  23. main/library/algorithm/residuals.py +0 -170
  24. main/library/algorithm/separator.py +0 -420
  25. main/library/algorithm/synthesizers.py +0 -590
  26. main/library/architectures/demucs_separator.py +0 -340
  27. main/library/architectures/mdx_separator.py +0 -370
  28. main/library/predictors/FCPE.py +0 -600
  29. main/library/predictors/RMVPE.py +0 -270
  30. main/library/uvr5_separator/common_separator.py +0 -270
  31. main/library/uvr5_separator/demucs/apply.py +0 -280
  32. main/library/uvr5_separator/demucs/demucs.py +0 -340
  33. main/library/uvr5_separator/demucs/hdemucs.py +0 -850
  34. main/library/uvr5_separator/demucs/htdemucs.py +0 -690
  35. main/library/uvr5_separator/demucs/states.py +0 -70
  36. main/library/uvr5_separator/demucs/utils.py +0 -10
  37. main/library/uvr5_separator/spec_utils.py +0 -1100
  38. main/tools/gdown.py +0 -230
  39. main/tools/mediafire.py +0 -30
  40. main/tools/meganz.py +0 -180
  41. main/tools/pixeldrain.py +0 -19
main/app/app.py DELETED
The diff for this file is too large to render. See raw diff
 
main/app/clean.py DELETED
@@ -1,40 +0,0 @@
- import time
- import threading
-
- from googleapiclient.discovery import build
-
-
- class Clean:
-     def __init__(self, every=300):
-         self.service = build('drive', 'v3')
-         self.every = every
-         self.trash_cleanup_thread = None
-
-     def delete(self):
-         page_token = None
-
-         while 1:
-             response = self.service.files().list(q="trashed=true", spaces='drive', fields="nextPageToken, files(id, name)", pageToken=page_token).execute()
-
-             for file in response.get('files', []):
-                 if file['name'].startswith("G_") or file['name'].startswith("D_"):
-                     try:
-                         self.service.files().delete(fileId=file['id']).execute()
-                     except Exception as e:
-                         raise RuntimeError(e)
-
-             page_token = response.get('nextPageToken', None)
-             if page_token is None: break
-
-     def clean(self):
-         while 1:
-             self.delete()
-             time.sleep(self.every)
-
-     def start(self):
-         self.trash_cleanup_thread = threading.Thread(target=self.clean)
-         self.trash_cleanup_thread.daemon = True
-         self.trash_cleanup_thread.start()
-
-     def stop(self):
-         if self.trash_cleanup_thread: self.trash_cleanup_thread.join()
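For orientation, a minimal usage sketch of the deleted Clean helper, assuming Google Drive API credentials are already available to googleapiclient in the environment; the variable names are illustrative only:

    from main.app.clean import Clean

    cleaner = Clean(every=600)  # poll the Drive trash every 10 minutes
    cleaner.start()             # daemon thread that deletes trashed "G_*" / "D_*" checkpoint files
    # Note: clean() loops forever, so stop() only returns once the thread has otherwise terminated.
    cleaner.stop()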
main/app/run_tensorboard.py DELETED
@@ -1,30 +0,0 @@
- import os
- import sys
- import json
- import logging
- import webbrowser
-
- from tensorboard import program
-
- sys.path.append(os.getcwd())
-
- from main.configs.config import Config
- translations = Config().translations
-
- with open(os.path.join("main", "configs", "config.json"), "r") as f:
-     configs = json.load(f)
-
- def launch_tensorboard_pipeline():
-     logging.getLogger("root").setLevel(logging.ERROR)
-     logging.getLogger("tensorboard").setLevel(logging.ERROR)
-
-     tb = program.TensorBoard()
-     tb.configure(argv=[None, "--logdir", "assets/logs", f"--port={configs["tensorboard_port"]}"])
-     url = tb.launch()
-
-     print(f"{translations['tensorboard_url']}: {url}")
-     webbrowser.open(url)
-
-     return f"{translations['tensorboard_url']}: {url}"
-
- if __name__ == "__main__": launch_tensorboard_pipeline()
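A sketch of how this entry point was driven, assuming the repository layout above: the port comes from "tensorboard_port" in main/configs/config.json (6870 at the time), which this commit also removes.

    from main.app.run_tensorboard import launch_tensorboard_pipeline

    # Serves assets/logs on the configured port and opens the URL in the default browser.
    message = launch_tensorboard_pipeline()
    print(message)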
main/app/sync.py DELETED
@@ -1,80 +0,0 @@
- import time
- import threading
- import subprocess
-
- from typing import List, Union
-
-
- class Channel:
-     def __init__(self, source, destination, sync_deletions=False, every=60, exclude: Union[str, List, None] = None):
-         self.source = source
-         self.destination = destination
-         self.event = threading.Event()
-         self.syncing_thread = threading.Thread(target=self._sync, args=())
-         self.sync_deletions = sync_deletions
-         self.every = every
-
-
-         if not exclude: exclude = []
-         if isinstance(exclude,str): exclude = [exclude]
-
-         self.exclude = exclude
-         self.command = ['rsync', '-aP']
-
-
-     def alive(self):
-         if self.syncing_thread.is_alive(): return True
-         else: return False
-
-
-     def _sync(self):
-         command = self.command
-
-         for exclusion in self.exclude:
-             command.append(f'--exclude={exclusion}')
-
-         command.extend([f'{self.source}/', f'{self.destination}/'])
-
-         if self.sync_deletions: command.append('--delete')
-
-         while not self.event.is_set():
-             subprocess.run(command)
-             time.sleep(self.every)
-
-
-     def copy(self):
-         command = self.command
-
-         for exclusion in self.exclude:
-             command.append(f'--exclude={exclusion}')
-
-         command.extend([f'{self.source}/', f'{self.destination}/'])
-
-         if self.sync_deletions: command.append('--delete')
-         subprocess.run(command)
-
-         return True
-
-
-     def start(self):
-         if self.syncing_thread.is_alive():
-             self.event.set()
-             self.syncing_thread.join()
-
-         if self.event.is_set(): self.event.clear()
-         if self.syncing_thread._started.is_set(): self.syncing_thread = threading.Thread(target=self._sync, args=())
-
-         self.syncing_thread.start()
-
-         return self.alive()
-
-
-     def stop(self):
-         if self.alive():
-             self.event.set()
-             self.syncing_thread.join()
-
-         while self.alive():
-             if not self.alive(): break
-
-         return not self.alive()
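A usage sketch of the deleted Channel wrapper around rsync; the source and destination paths below are placeholders. It mirrors source/ into destination/ every `every` seconds on a background thread:

    from main.app.sync import Channel

    channel = Channel("/content/assets/logs", "/content/drive/MyDrive/logs",
                      sync_deletions=False, every=60, exclude=["*.tmp"])
    channel.start()  # periodic rsync -aP source/ destination/
    channel.copy()   # one-off sync with the same options
    channel.stop()   # set the stop event and join the thread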
main/configs/config.json DELETED
@@ -1,229 +0,0 @@
1
- {
2
- "language": "vi-VN",
3
- "support_language": [
4
- "en-US",
5
- "vi-VN"
6
- ],
7
-
8
-
9
- "theme": "NoCrypt/miku",
10
- "themes": [
11
- "NoCrypt/miku",
12
- "gstaff/xkcd",
13
- "JohnSmith9982/small_and_pretty",
14
- "ParityError/Interstellar",
15
- "earneleh/paris",
16
- "shivi/calm_seafoam",
17
- "Hev832/Applio",
18
- "YTheme/Minecraft",
19
- "gstaff/sketch",
20
- "SebastianBravo/simci_css",
21
- "allenai/gradio-theme",
22
- "Nymbo/Nymbo_Theme_5",
23
- "lone17/kotaemon",
24
- "Zarkel/IBM_Carbon_Theme",
25
- "SherlockRamos/Feliz",
26
- "freddyaboulton/dracula_revamped",
27
- "freddyaboulton/bad-theme-space",
28
- "gradio/dracula_revamped",
29
- "abidlabs/dracula_revamped",
30
- "gradio/dracula_test",
31
- "gradio/seafoam",
32
- "gradio/glass",
33
- "gradio/monochrome",
34
- "gradio/soft",
35
- "gradio/default",
36
- "gradio/base",
37
- "abidlabs/pakistan",
38
- "dawood/microsoft_windows",
39
- "ysharma/steampunk",
40
- "ysharma/huggingface",
41
- "abidlabs/Lime",
42
- "freddyaboulton/this-theme-does-not-exist-2",
43
- "aliabid94/new-theme",
44
- "aliabid94/test2",
45
- "aliabid94/test3",
46
- "aliabid94/test4",
47
- "abidlabs/banana",
48
- "freddyaboulton/test-blue",
49
- "gstaff/whiteboard",
50
- "ysharma/llamas",
51
- "abidlabs/font-test",
52
- "YenLai/Superhuman",
53
- "bethecloud/storj_theme",
54
- "sudeepshouche/minimalist",
55
- "knotdgaf/gradiotest",
56
- "ParityError/Anime",
57
- "Ajaxon6255/Emerald_Isle",
58
- "ParityError/LimeFace",
59
- "finlaymacklon/smooth_slate",
60
- "finlaymacklon/boxy_violet",
61
- "derekzen/stardust",
62
- "EveryPizza/Cartoony-Gradio-Theme",
63
- "Ifeanyi/Cyanister",
64
- "Tshackelton/IBMPlex-DenseReadable",
65
- "snehilsanyal/scikit-learn",
66
- "Himhimhim/xkcd",
67
- "nota-ai/theme",
68
- "rawrsor1/Everforest",
69
- "rottenlittlecreature/Moon_Goblin",
70
- "abidlabs/test-yellow",
71
- "abidlabs/test-yellow3",
72
- "idspicQstitho/dracula_revamped",
73
- "kfahn/AnimalPose",
74
- "HaleyCH/HaleyCH_Theme",
75
- "simulKitke/dracula_test",
76
- "braintacles/CrimsonNight",
77
- "wentaohe/whiteboardv2",
78
- "reilnuud/polite",
79
- "remilia/Ghostly",
80
- "Franklisi/darkmode",
81
- "coding-alt/soft",
82
- "xiaobaiyuan/theme_land",
83
- "step-3-profit/Midnight-Deep",
84
- "xiaobaiyuan/theme_demo",
85
- "Taithrah/Minimal",
86
- "Insuz/SimpleIndigo",
87
- "zkunn/Alipay_Gradio_theme",
88
- "Insuz/Mocha",
89
- "xiaobaiyuan/theme_brief",
90
- "Ama434/434-base-Barlow",
91
- "Ama434/def_barlow",
92
- "Ama434/neutral-barlow",
93
- "dawood/dracula_test",
94
- "nuttea/Softblue",
95
- "BlueDancer/Alien_Diffusion",
96
- "naughtondale/monochrome",
97
- "Dagfinn1962/standard",
98
- "default"
99
- ],
100
-
101
-
102
- "tts_voice": [
103
- "af-ZA-AdriNeural", "af-ZA-WillemNeural", "sq-AL-AnilaNeural",
104
- "sq-AL-IlirNeural", "am-ET-AmehaNeural", "am-ET-MekdesNeural",
105
- "ar-DZ-AminaNeural", "ar-DZ-IsmaelNeural", "ar-BH-AliNeural",
106
- "ar-BH-LailaNeural", "ar-EG-SalmaNeural", "ar-EG-ShakirNeural",
107
- "ar-IQ-BasselNeural", "ar-IQ-RanaNeural", "ar-JO-SanaNeural",
108
- "ar-JO-TaimNeural", "ar-KW-FahedNeural", "ar-KW-NouraNeural",
109
- "ar-LB-LaylaNeural", "ar-LB-RamiNeural", "ar-LY-ImanNeural",
110
- "ar-LY-OmarNeural", "ar-MA-JamalNeural", "ar-MA-MounaNeural",
111
- "ar-OM-AbdullahNeural", "ar-OM-AyshaNeural", "ar-QA-AmalNeural",
112
- "ar-QA-MoazNeural", "ar-SA-HamedNeural", "ar-SA-ZariyahNeural",
113
- "ar-SY-AmanyNeural", "ar-SY-LaithNeural", "ar-TN-HediNeural",
114
- "ar-TN-ReemNeural", "ar-AE-FatimaNeural", "ar-AE-HamdanNeural",
115
- "ar-YE-MaryamNeural", "ar-YE-SalehNeural", "az-AZ-BabekNeural",
116
- "az-AZ-BanuNeural", "bn-BD-NabanitaNeural", "bn-BD-PradeepNeural",
117
- "bn-IN-BashkarNeural", "bn-IN-TanishaaNeural", "bs-BA-GoranNeural",
118
- "bs-BA-VesnaNeural", "bg-BG-BorislavNeural", "bg-BG-KalinaNeural",
119
- "my-MM-NilarNeural", "my-MM-ThihaNeural", "ca-ES-EnricNeural",
120
- "ca-ES-JoanaNeural", "zh-HK-HiuGaaiNeural", "zh-HK-HiuMaanNeural",
121
- "zh-HK-WanLungNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural",
122
- "zh-CN-YunjianNeural", "zh-CN-YunxiNeural", "zh-CN-YunxiaNeural",
123
- "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-TW-HsiaoChenNeural",
124
- "zh-TW-YunJheNeural", "zh-TW-HsiaoYuNeural", "zh-CN-shaanxi-XiaoniNeural",
125
- "hr-HR-GabrijelaNeural", "hr-HR-SreckoNeural", "cs-CZ-AntoninNeural",
126
- "cs-CZ-VlastaNeural", "da-DK-ChristelNeural", "da-DK-JeppeNeural",
127
- "nl-BE-ArnaudNeural", "nl-BE-DenaNeural", "nl-NL-ColetteNeural",
128
- "nl-NL-FennaNeural", "nl-NL-MaartenNeural", "en-AU-NatashaNeural",
129
- "en-AU-WilliamNeural", "en-CA-ClaraNeural", "en-CA-LiamNeural",
130
- "en-HK-SamNeural", "en-HK-YanNeural", "en-IN-NeerjaExpressiveNeural",
131
- "en-IN-NeerjaNeural", "en-IN-PrabhatNeural", "en-IE-ConnorNeural",
132
- "en-IE-EmilyNeural", "en-KE-AsiliaNeural", "en-KE-ChilembaNeural",
133
- "en-NZ-MitchellNeural", "en-NZ-MollyNeural", "en-NG-AbeoNeural",
134
- "en-NG-EzinneNeural", "en-PH-JamesNeural", "en-PH-RosaNeural",
135
- "en-SG-LunaNeural", "en-SG-WayneNeural", "en-ZA-LeahNeural",
136
- "en-ZA-LukeNeural", "en-TZ-ElimuNeural", "en-TZ-ImaniNeural",
137
- "en-GB-LibbyNeural", "en-GB-MaisieNeural", "en-GB-RyanNeural",
138
- "en-GB-SoniaNeural", "en-GB-ThomasNeural", "en-US-AvaMultilingualNeural",
139
- "en-US-AndrewMultilingualNeural", "en-US-EmmaMultilingualNeural",
140
- "en-US-BrianMultilingualNeural", "en-US-AvaNeural", "en-US-AndrewNeural",
141
- "en-US-EmmaNeural", "en-US-BrianNeural", "en-US-AnaNeural", "en-US-AriaNeural",
142
- "en-US-ChristopherNeural", "en-US-EricNeural", "en-US-GuyNeural",
143
- "en-US-JennyNeural", "en-US-MichelleNeural", "en-US-RogerNeural",
144
- "en-US-SteffanNeural", "et-EE-AnuNeural", "et-EE-KertNeural",
145
- "fil-PH-AngeloNeural", "fil-PH-BlessicaNeural", "fi-FI-HarriNeural",
146
- "fi-FI-NooraNeural", "fr-BE-CharlineNeural", "fr-BE-GerardNeural",
147
- "fr-CA-ThierryNeural", "fr-CA-AntoineNeural", "fr-CA-JeanNeural",
148
- "fr-CA-SylvieNeural", "fr-FR-VivienneMultilingualNeural", "fr-FR-RemyMultilingualNeural",
149
- "fr-FR-DeniseNeural", "fr-FR-EloiseNeural", "fr-FR-HenriNeural",
150
- "fr-CH-ArianeNeural", "fr-CH-FabriceNeural", "gl-ES-RoiNeural",
151
- "gl-ES-SabelaNeural", "ka-GE-EkaNeural", "ka-GE-GiorgiNeural",
152
- "de-AT-IngridNeural", "de-AT-JonasNeural", "de-DE-SeraphinaMultilingualNeural",
153
- "de-DE-FlorianMultilingualNeural", "de-DE-AmalaNeural", "de-DE-ConradNeural",
154
- "de-DE-KatjaNeural", "de-DE-KillianNeural", "de-CH-JanNeural",
155
- "de-CH-LeniNeural", "el-GR-AthinaNeural", "el-GR-NestorasNeural",
156
- "gu-IN-DhwaniNeural", "gu-IN-NiranjanNeural", "he-IL-AvriNeural",
157
- "he-IL-HilaNeural", "hi-IN-MadhurNeural", "hi-IN-SwaraNeural",
158
- "hu-HU-NoemiNeural", "hu-HU-TamasNeural", "is-IS-GudrunNeural",
159
- "is-IS-GunnarNeural", "id-ID-ArdiNeural", "id-ID-GadisNeural",
160
- "ga-IE-ColmNeural", "ga-IE-OrlaNeural", "it-IT-GiuseppeNeural",
161
- "it-IT-DiegoNeural", "it-IT-ElsaNeural", "it-IT-IsabellaNeural",
162
- "ja-JP-KeitaNeural", "ja-JP-NanamiNeural", "jv-ID-DimasNeural",
163
- "jv-ID-SitiNeural", "kn-IN-GaganNeural", "kn-IN-SapnaNeural",
164
- "kk-KZ-AigulNeural", "kk-KZ-DauletNeural", "km-KH-PisethNeural",
165
- "km-KH-SreymomNeural", "ko-KR-HyunsuNeural", "ko-KR-InJoonNeural",
166
- "ko-KR-SunHiNeural", "lo-LA-ChanthavongNeural", "lo-LA-KeomanyNeural",
167
- "lv-LV-EveritaNeural", "lv-LV-NilsNeural", "lt-LT-LeonasNeural",
168
- "lt-LT-OnaNeural", "mk-MK-AleksandarNeural", "mk-MK-MarijaNeural",
169
- "ms-MY-OsmanNeural", "ms-MY-YasminNeural", "ml-IN-MidhunNeural",
170
- "ml-IN-SobhanaNeural", "mt-MT-GraceNeural", "mt-MT-JosephNeural",
171
- "mr-IN-AarohiNeural", "mr-IN-ManoharNeural", "mn-MN-BataaNeural",
172
- "mn-MN-YesuiNeural", "ne-NP-HemkalaNeural", "ne-NP-SagarNeural",
173
- "nb-NO-FinnNeural", "nb-NO-PernilleNeural", "ps-AF-GulNawazNeural",
174
- "ps-AF-LatifaNeural", "fa-IR-DilaraNeural", "fa-IR-FaridNeural",
175
- "pl-PL-MarekNeural", "pl-PL-ZofiaNeural", "pt-BR-ThalitaNeural",
176
- "pt-BR-AntonioNeural", "pt-BR-FranciscaNeural", "pt-PT-DuarteNeural",
177
- "pt-PT-RaquelNeural", "ro-RO-AlinaNeural", "ro-RO-EmilNeural",
178
- "ru-RU-DmitryNeural", "ru-RU-SvetlanaNeural", "sr-RS-NicholasNeural",
179
- "sr-RS-SophieNeural", "si-LK-SameeraNeural", "si-LK-ThiliniNeural",
180
- "sk-SK-LukasNeural", "sk-SK-ViktoriaNeural", "sl-SI-PetraNeural",
181
- "sl-SI-RokNeural", "so-SO-MuuseNeural", "so-SO-UbaxNeural",
182
- "es-AR-ElenaNeural", "es-AR-TomasNeural", "es-BO-MarceloNeural",
183
- "es-BO-SofiaNeural", "es-CL-CatalinaNeural", "es-CL-LorenzoNeural",
184
- "es-ES-XimenaNeural", "es-CO-GonzaloNeural", "es-CO-SalomeNeural",
185
- "es-CR-JuanNeural", "es-CR-MariaNeural", "es-CU-BelkysNeural",
186
- "es-CU-ManuelNeural", "es-DO-EmilioNeural", "es-DO-RamonaNeural",
187
- "es-EC-AndreaNeural", "es-EC-LuisNeural", "es-SV-LorenaNeural",
188
- "es-SV-RodrigoNeural", "es-GQ-JavierNeural", "es-GQ-TeresaNeural",
189
- "es-GT-AndresNeural", "es-GT-MartaNeural", "es-HN-CarlosNeural",
190
- "es-HN-KarlaNeural", "es-MX-DaliaNeural", "es-MX-JorgeNeural",
191
- "es-NI-FedericoNeural", "es-NI-YolandaNeural", "es-PA-MargaritaNeural",
192
- "es-PA-RobertoNeural", "es-PY-MarioNeural", "es-PY-TaniaNeural",
193
- "es-PE-AlexNeural", "es-PE-CamilaNeural", "es-PR-KarinaNeural",
194
- "es-PR-VictorNeural", "es-ES-AlvaroNeural", "es-ES-ElviraNeural",
195
- "es-US-AlonsoNeural", "es-US-PalomaNeural", "es-UY-MateoNeural",
196
- "es-UY-ValentinaNeural", "es-VE-PaolaNeural", "es-VE-SebastianNeural",
197
- "su-ID-JajangNeural", "su-ID-TutiNeural", "sw-KE-RafikiNeural",
198
- "sw-KE-ZuriNeural", "sw-TZ-DaudiNeural", "sw-TZ-RehemaNeural",
199
- "sv-SE-MattiasNeural", "sv-SE-SofieNeural", "ta-IN-PallaviNeural",
200
- "ta-IN-ValluvarNeural", "ta-MY-KaniNeural", "ta-MY-SuryaNeural",
201
- "ta-SG-AnbuNeural", "ta-SG-VenbaNeural", "ta-LK-KumarNeural",
202
- "ta-LK-SaranyaNeural", "te-IN-MohanNeural", "te-IN-ShrutiNeural",
203
- "th-TH-NiwatNeural", "th-TH-PremwadeeNeural", "tr-TR-AhmetNeural",
204
- "tr-TR-EmelNeural", "uk-UA-OstapNeural", "uk-UA-PolinaNeural",
205
- "ur-IN-GulNeural", "ur-IN-SalmanNeural", "ur-PK-AsadNeural",
206
- "ur-PK-UzmaNeural", "uz-UZ-MadinaNeural", "uz-UZ-SardorNeural",
207
- "vi-VN-HoaiMyNeural", "vi-VN-NamMinhNeural", "cy-GB-AledNeural",
208
- "cy-GB-NiaNeural", "zu-ZA-ThandoNeural", "zu-ZA-ThembaNeural"
209
- ],
210
-
211
-
212
- "separator_tab": false,
213
- "convert_tab": true,
214
- "tts_tab": true,
215
- "effects_tab": true,
216
- "create_dataset_tab": false,
217
- "training_tab": false,
218
- "fushion_tab": false,
219
- "read_tab": false,
220
- "downloads_tab": true,
221
- "settings_tab": false,
222
-
223
- "app_port": 7860,
224
- "tensorboard_port": 6870,
225
- "num_of_restart": 5,
226
- "server_name": "0.0.0.0",
227
- "app_show_error": true,
228
- "share": false
229
- }
 
main/configs/config.py DELETED
@@ -1,120 +0,0 @@
- import os
- import json
- import torch
-
-
- version_config_paths = [
-     os.path.join("v1", "32000.json"),
-     os.path.join("v1", "40000.json"),
-     os.path.join("v1", "48000.json"),
-     os.path.join("v2", "48000.json"),
-     os.path.join("v2", "40000.json"),
-     os.path.join("v2", "32000.json"),
- ]
-
-
- def singleton(cls):
-     instances = {}
-
-     def get_instance(*args, **kwargs):
-         if cls not in instances: instances[cls] = cls(*args, **kwargs)
-
-         return instances[cls]
-     return get_instance
-
-
- @singleton
- class Config:
-     def __init__(self):
-         self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
-         self.is_half = self.device != "cpu"
-         self.gpu_name = (torch.cuda.get_device_name(int(self.device.split(":")[-1])) if self.device.startswith("cuda") else None)
-         self.json_config = self.load_config_json()
-         self.translations = self.multi_language()
-         self.gpu_mem = None
-         self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()
-
-
-     def load_config_json(self) -> dict:
-         configs = {}
-
-         for config_file in version_config_paths:
-             config_path = os.path.join("main", "configs", config_file)
-
-             with open(config_path, "r") as f:
-                 configs[config_file] = json.load(f)
-
-         return configs
-
-     def multi_language(self):
-         with open(os.path.join("main", "configs", "config.json"), "r") as f:
-             configs = json.load(f)
-
-         lang = configs["language"]
-
-         if len([l for l in os.listdir(os.path.join("assets", "languages")) if l.endswith(".json")]) < 1: raise FileNotFoundError("Không tìm thấy bất cứ gói ngôn ngữ nào(No package languages found)")
-
-         if not lang: lang = "vi-VN"
-         if lang not in configs["support_language"]: raise ValueError("Ngôn ngữ không được hỗ trợ(Language not supported)")
-
-         lang_path = os.path.join("assets", "languages", f"{lang}.json")
-         if not os.path.exists(lang_path): lang_path = os.path.join("assets", "languages", f"vi-VN.json")
-
-         with open(lang_path, encoding="utf-8") as f:
-             translations = json.load(f)
-
-
-         return translations
-
-     def set_precision(self, precision):
-         if precision not in ["fp32", "fp16"]: raise ValueError("Loại chính xác không hợp lệ. Phải là 'fp32' hoặc 'fp16'(Invalid precision type. Must be 'fp32' or 'fp16').")
-
-         fp16_run_value = precision == "fp16"
-
-         for config_path in version_config_paths:
-             full_config_path = os.path.join("main", "configs", config_path)
-
-             try:
-                 with open(full_config_path, "r") as f:
-                     config = json.load(f)
-
-                 config["train"]["fp16_run"] = fp16_run_value
-
-                 with open(full_config_path, "w") as f:
-                     json.dump(config, f, indent=4)
-             except FileNotFoundError:
-                 print(self.translations["not_found"].format(name=full_config_path))
-
-         return self.translations["set_precision"].format(precision=precision)
-
-     def device_config(self) -> tuple:
-         if self.device.startswith("cuda"):
-             self.set_cuda_config()
-         elif self.has_mps():
-             self.device = "mps"
-             self.is_half = False
-             self.set_precision("fp32")
-         else:
-             self.device = "cpu"
-             self.is_half = False
-             self.set_precision("fp32")
-
-         x_pad, x_query, x_center, x_max = ((3, 10, 60, 65) if self.is_half else (1, 6, 38, 41))
-
-         if self.gpu_mem is not None and self.gpu_mem <= 4: x_pad, x_query, x_center, x_max = (1, 5, 30, 32)
-
-         return x_pad, x_query, x_center, x_max
-
-     def set_cuda_config(self):
-         i_device = int(self.device.split(":")[-1])
-         self.gpu_name = torch.cuda.get_device_name(i_device)
-         low_end_gpus = ["16", "P40", "P10", "1060", "1070", "1080"]
-
-         if (any(gpu in self.gpu_name for gpu in low_end_gpus) and "V100" not in self.gpu_name.upper()):
-             self.is_half = False
-             self.set_precision("fp32")
-
-         self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)
-
-     def has_mps(self) -> bool:
-         return torch.backends.mps.is_available()
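Because Config is wrapped in the singleton decorator, every import site shares one instance; a small sketch of how the rest of the deleted code consumed it:

    from main.configs.config import Config

    config = Config()            # the same object is returned on every call
    assert config is Config()
    print(config.device, config.is_half)
    translations = config.translations  # UI strings loaded from assets/languages/<lang>.json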
main/configs/v1/32000.json DELETED
@@ -1,82 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "epochs": 20000,
-         "learning_rate": 0.0001,
-         "betas": [
-             0.8,
-             0.99
-         ],
-         "eps": 1e-09,
-         "batch_size": 4,
-         "fp16_run": false,
-         "lr_decay": 0.999875,
-         "segment_size": 12800,
-         "init_lr_ratio": 1,
-         "warmup_epochs": 0,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 32000,
-         "filter_length": 1024,
-         "hop_length": 320,
-         "win_length": 1024,
-         "n_mel_channels": 80,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 256,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [
-             3,
-             7,
-             11
-         ],
-         "resblock_dilation_sizes": [
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ]
-         ],
-         "upsample_rates": [
-             10,
-             4,
-             2,
-             2,
-             2
-         ],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [
-             16,
-             16,
-             4,
-             4,
-             4
-         ],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }
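One internal relationship worth noting in these synthesizer configs: the product of model.upsample_rates equals data.hop_length, so the decoder upsamples frame-rate features back to waveform rate (10·4·2·2·2 = 320 here, and the same holds for the other five files below). An illustrative check against the pre-deletion file:

    import json
    from math import prod

    with open("main/configs/v1/32000.json") as f:  # path as it existed before this commit
        cfg = json.load(f)

    assert prod(cfg["model"]["upsample_rates"]) == cfg["data"]["hop_length"]  # 10*4*2*2*2 == 320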
main/configs/v1/40000.json DELETED
@@ -1,80 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "epochs": 20000,
-         "learning_rate": 0.0001,
-         "betas": [
-             0.8,
-             0.99
-         ],
-         "eps": 1e-09,
-         "batch_size": 4,
-         "fp16_run": false,
-         "lr_decay": 0.999875,
-         "segment_size": 12800,
-         "init_lr_ratio": 1,
-         "warmup_epochs": 0,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 40000,
-         "filter_length": 2048,
-         "hop_length": 400,
-         "win_length": 2048,
-         "n_mel_channels": 125,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 256,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [
-             3,
-             7,
-             11
-         ],
-         "resblock_dilation_sizes": [
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ]
-         ],
-         "upsample_rates": [
-             10,
-             10,
-             2,
-             2
-         ],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [
-             16,
-             16,
-             4,
-             4
-         ],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }
main/configs/v1/48000.json DELETED
@@ -1,82 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "epochs": 20000,
-         "learning_rate": 0.0001,
-         "betas": [
-             0.8,
-             0.99
-         ],
-         "eps": 1e-09,
-         "batch_size": 4,
-         "fp16_run": false,
-         "lr_decay": 0.999875,
-         "segment_size": 11520,
-         "init_lr_ratio": 1,
-         "warmup_epochs": 0,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 48000,
-         "filter_length": 2048,
-         "hop_length": 480,
-         "win_length": 2048,
-         "n_mel_channels": 128,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 256,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [
-             3,
-             7,
-             11
-         ],
-         "resblock_dilation_sizes": [
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ]
-         ],
-         "upsample_rates": [
-             10,
-             6,
-             2,
-             2,
-             2
-         ],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [
-             16,
-             16,
-             4,
-             4,
-             4
-         ],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }
main/configs/v2/32000.json DELETED
@@ -1,76 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "learning_rate": 0.0001,
-         "betas": [
-             0.8,
-             0.99
-         ],
-         "eps": 1e-09,
-         "fp16_run": false,
-         "lr_decay": 0.999875,
-         "segment_size": 12800,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 32000,
-         "filter_length": 1024,
-         "hop_length": 320,
-         "win_length": 1024,
-         "n_mel_channels": 80,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 768,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [
-             3,
-             7,
-             11
-         ],
-         "resblock_dilation_sizes": [
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ]
-         ],
-         "upsample_rates": [
-             10,
-             8,
-             2,
-             2
-         ],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [
-             20,
-             16,
-             4,
-             4
-         ],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }
main/configs/v2/40000.json DELETED
@@ -1,76 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "learning_rate": 0.0001,
-         "betas": [
-             0.8,
-             0.99
-         ],
-         "eps": 1e-09,
-         "fp16_run": false,
-         "lr_decay": 0.999875,
-         "segment_size": 12800,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 40000,
-         "filter_length": 2048,
-         "hop_length": 400,
-         "win_length": 2048,
-         "n_mel_channels": 125,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 768,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [
-             3,
-             7,
-             11
-         ],
-         "resblock_dilation_sizes": [
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ]
-         ],
-         "upsample_rates": [
-             10,
-             10,
-             2,
-             2
-         ],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [
-             16,
-             16,
-             4,
-             4
-         ],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }
main/configs/v2/48000.json DELETED
@@ -1,76 +0,0 @@
- {
-     "train": {
-         "log_interval": 200,
-         "seed": 1234,
-         "learning_rate": 0.0001,
-         "betas": [
-             0.8,
-             0.99
-         ],
-         "eps": 1e-09,
-         "fp16_run": false,
-         "lr_decay": 0.999875,
-         "segment_size": 17280,
-         "c_mel": 45,
-         "c_kl": 1.0
-     },
-     "data": {
-         "max_wav_value": 32768.0,
-         "sample_rate": 48000,
-         "filter_length": 2048,
-         "hop_length": 480,
-         "win_length": 2048,
-         "n_mel_channels": 128,
-         "mel_fmin": 0.0,
-         "mel_fmax": null
-     },
-     "model": {
-         "inter_channels": 192,
-         "hidden_channels": 192,
-         "filter_channels": 768,
-         "text_enc_hidden_dim": 768,
-         "n_heads": 2,
-         "n_layers": 6,
-         "kernel_size": 3,
-         "p_dropout": 0,
-         "resblock": "1",
-         "resblock_kernel_sizes": [
-             3,
-             7,
-             11
-         ],
-         "resblock_dilation_sizes": [
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ],
-             [
-                 1,
-                 3,
-                 5
-             ]
-         ],
-         "upsample_rates": [
-             12,
-             10,
-             2,
-             2
-         ],
-         "upsample_initial_channel": 512,
-         "upsample_kernel_sizes": [
-             24,
-             20,
-             4,
-             4
-         ],
-         "use_spectral_norm": false,
-         "gin_channels": 256,
-         "spk_embed_dim": 109
-     }
- }
main/inference/audio_effects.py DELETED
@@ -1,170 +0,0 @@
1
- import os
2
- import sys
3
- import librosa
4
- import argparse
5
-
6
- import numpy as np
7
- import soundfile as sf
8
-
9
- from distutils.util import strtobool
10
- from scipy.signal import butter, filtfilt
11
- from pedalboard import Pedalboard, Chorus, Distortion, Reverb, PitchShift, Delay, Limiter, Gain, Bitcrush, Clipping, Compressor, Phaser
12
-
13
- now_dir = os.getcwd()
14
- sys.path.append(now_dir)
15
-
16
- from main.configs.config import Config
17
- translations = Config().translations
18
-
19
- def parse_arguments() -> tuple:
20
- parser = argparse.ArgumentParser()
21
- parser.add_argument("--input_path", type=str, required=True)
22
- parser.add_argument("--output_path", type=str, default="./audios/apply_effects.wav")
23
- parser.add_argument("--resample", type=lambda x: bool(strtobool(x)), default=False)
24
- parser.add_argument("--resample_sr", type=int, default=0)
25
- parser.add_argument("--chorus", type=lambda x: bool(strtobool(x)), default=False)
26
- parser.add_argument("--chorus_depth", type=float, default=0.5)
27
- parser.add_argument("--chorus_rate", type=float, default=1.5)
28
- parser.add_argument("--chorus_mix", type=float, default=0.5)
29
- parser.add_argument("--chorus_delay", type=int, default=10)
30
- parser.add_argument("--chorus_feedback", type=float, default=0)
31
- parser.add_argument("--distortion", type=lambda x: bool(strtobool(x)), default=False)
32
- parser.add_argument("--drive_db", type=int, default=20)
33
- parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
34
- parser.add_argument("--reverb_room_size", type=float, default=0.5)
35
- parser.add_argument("--reverb_damping", type=float, default=0.5)
36
- parser.add_argument("--reverb_wet_level", type=float, default=0.33)
37
- parser.add_argument("--reverb_dry_level", type=float, default=0.67)
38
- parser.add_argument("--reverb_width", type=float, default=1)
39
- parser.add_argument("--reverb_freeze_mode", type=lambda x: bool(strtobool(x)), default=False)
40
- parser.add_argument("--pitchshift", type=lambda x: bool(strtobool(x)), default=False)
41
- parser.add_argument("--pitch_shift", type=int, default=0)
42
- parser.add_argument("--delay", type=lambda x: bool(strtobool(x)), default=False)
43
- parser.add_argument("--delay_seconds", type=float, default=0.5)
44
- parser.add_argument("--delay_feedback", type=float, default=0.5)
45
- parser.add_argument("--delay_mix", type=float, default=0.5)
46
- parser.add_argument("--compressor", type=lambda x: bool(strtobool(x)), default=False)
47
- parser.add_argument("--compressor_threshold", type=int, default=-20)
48
- parser.add_argument("--compressor_ratio", type=float, default=4)
49
- parser.add_argument("--compressor_attack_ms", type=float, default=10)
50
- parser.add_argument("--compressor_release_ms", type=int, default=200)
51
- parser.add_argument("--limiter", type=lambda x: bool(strtobool(x)), default=False)
52
- parser.add_argument("--limiter_threshold", type=int, default=0)
53
- parser.add_argument("--limiter_release", type=int, default=100)
54
- parser.add_argument("--gain", type=lambda x: bool(strtobool(x)), default=False)
55
- parser.add_argument("--gain_db", type=int, default=0)
56
- parser.add_argument("--bitcrush", type=lambda x: bool(strtobool(x)), default=False)
57
- parser.add_argument("--bitcrush_bit_depth", type=int, default=16)
58
- parser.add_argument("--clipping", type=lambda x: bool(strtobool(x)), default=False)
59
- parser.add_argument("--clipping_threshold", type=int, default=-10)
60
- parser.add_argument("--phaser", type=lambda x: bool(strtobool(x)), default=False)
61
- parser.add_argument("--phaser_rate_hz", type=float, default=0.5)
62
- parser.add_argument("--phaser_depth", type=float, default=0.5)
63
- parser.add_argument("--phaser_centre_frequency_hz", type=int, default=1000)
64
- parser.add_argument("--phaser_feedback", type=float, default=0)
65
- parser.add_argument("--phaser_mix", type=float, default=0.5)
66
- parser.add_argument("--treble_bass_boost", type=lambda x: bool(strtobool(x)), default=False)
67
- parser.add_argument("--bass_boost_db", type=int, default=0)
68
- parser.add_argument("--bass_boost_frequency", type=int, default=100)
69
- parser.add_argument("--treble_boost_db", type=int, default=0)
70
- parser.add_argument("--treble_boost_frequency", type=int, default=3000)
71
- parser.add_argument("--fade_in_out", type=lambda x: bool(strtobool(x)), default=False)
72
- parser.add_argument("--fade_in_duration", type=float, default=2000)
73
- parser.add_argument("--fade_out_duration", type=float, default=2000)
74
- parser.add_argument("--export_format", type=str, default="wav")
75
-
76
- args = parser.parse_args()
77
- return args
78
-
79
-
80
- def main():
81
- args = parse_arguments()
82
-
83
- process_audio(input_path=args.input_path, output_path=args.output_path, resample=args.resample, resample_sr=args.resample_sr, chorus_depth=args.chorus_depth, chorus_rate=args.chorus_rate, chorus_mix=args.chorus_mix, chorus_delay=args.chorus_delay, chorus_feedback=args.chorus_feedback, distortion_drive=args.drive_db, reverb_room_size=args.reverb_room_size, reverb_damping=args.reverb_damping, reverb_wet_level=args.reverb_wet_level, reverb_dry_level=args.reverb_dry_level, reverb_width=args.reverb_width, reverb_freeze_mode=args.reverb_freeze_mode, pitch_shift=args.pitch_shift, delay_seconds=args.delay_seconds, delay_feedback=args.delay_feedback, delay_mix=args.delay_mix, compressor_threshold=args.compressor_threshold, compressor_ratio=args.compressor_ratio, compressor_attack_ms=args.compressor_attack_ms, compressor_release_ms=args.compressor_release_ms, limiter_threshold=args.limiter_threshold, limiter_release=args.limiter_release, gain_db=args.gain_db, bitcrush_bit_depth=args.bitcrush_bit_depth, clipping_threshold=args.clipping_threshold, phaser_rate_hz=args.phaser_rate_hz, phaser_depth=args.phaser_depth, phaser_centre_frequency_hz=args.phaser_centre_frequency_hz, phaser_feedback=args.phaser_feedback, phaser_mix=args.phaser_mix, bass_boost_db=args.bass_boost_db, bass_boost_frequency=args.bass_boost_frequency, treble_boost_db=args.treble_boost_db, treble_boost_frequency=args.treble_boost_frequency, fade_in_duration=args.fade_in_duration, fade_out_duration=args.fade_out_duration, export_format=args.export_format, chorus=args.chorus, distortion=args.distortion, reverb=args.reverb, pitchshift=args.pitchshift, delay=args.delay, compressor=args.compressor, limiter=args.limiter, gain=args.gain, bitcrush=args.bitcrush, clipping=args.clipping, phaser=args.phaser, treble_bass_boost=args.treble_bass_boost, fade_in_out=args.fade_in_out)
84
-
85
-
86
- def process_audio(input_path, output_path, resample, resample_sr, chorus_depth, chorus_rate, chorus_mix, chorus_delay, chorus_feedback, distortion_drive, reverb_room_size, reverb_damping, reverb_wet_level, reverb_dry_level, reverb_width, reverb_freeze_mode, pitch_shift, delay_seconds, delay_feedback, delay_mix, compressor_threshold, compressor_ratio, compressor_attack_ms, compressor_release_ms, limiter_threshold, limiter_release, gain_db, bitcrush_bit_depth, clipping_threshold, phaser_rate_hz, phaser_depth, phaser_centre_frequency_hz, phaser_feedback, phaser_mix, bass_boost_db, bass_boost_frequency, treble_boost_db, treble_boost_frequency, fade_in_duration, fade_out_duration, export_format, chorus, distortion, reverb, pitchshift, delay, compressor, limiter, gain, bitcrush, clipping, phaser, treble_bass_boost, fade_in_out):
87
- def bass_boost(audio, gain_db, frequency, sample_rate):
88
- if gain_db >= 1:
89
- b, a = butter(4, frequency / (0.5 * sample_rate), btype='low')
90
-
91
- return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
92
- else: return audio
93
-
94
- def treble_boost(audio, gain_db, frequency, sample_rate):
95
- if gain_db >=1:
96
- b, a = butter(4, frequency / (0.5 * sample_rate), btype='high')
97
-
98
- return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
99
- else: return audio
100
-
101
- def fade_out_effect(audio, sr, duration=3.0):
102
- length = int(duration * sr)
103
- end = audio.shape[0]
104
-
105
- if length > end: length = end
106
- start = end - length
107
-
108
- audio[start:end] = audio[start:end] * np.linspace(1.0, 0.0, length)
109
- return audio
110
-
111
- def fade_in_effect(audio, sr, duration=3.0):
112
- length = int(duration * sr)
113
- start = 0
114
-
115
- if length > audio.shape[0]: length = audio.shape[0]
116
- end = length
117
-
118
- audio[start:end] = audio[start:end] * np.linspace(0.0, 1.0, length)
119
- return audio
120
-
121
- if not input_path or not os.path.exists(input_path):
122
- print(translations["input_not_valid"])
123
- sys.exit(1)
124
-
125
- if not output_path:
126
- print(translations["output_not_valid"])
127
- sys.exit(1)
128
-
129
- if os.path.exists(output_path): os.remove(output_path)
130
-
131
- try:
132
- audio, sample_rate = sf.read(input_path)
133
- except Exception as e:
134
- raise RuntimeError(translations["errors_loading_audio"].format(e=e))
135
-
136
- try:
137
- board = Pedalboard()
138
-
139
- if chorus: board.append(Chorus(depth=chorus_depth, rate_hz=chorus_rate, mix=chorus_mix, centre_delay_ms=chorus_delay, feedback=chorus_feedback))
140
- if distortion: board.append(Distortion(drive_db=distortion_drive))
141
- if reverb: board.append(Reverb(room_size=reverb_room_size, damping=reverb_damping, wet_level=reverb_wet_level, dry_level=reverb_dry_level, width=reverb_width, freeze_mode=1 if reverb_freeze_mode else 0))
142
- if pitchshift: board.append(PitchShift(semitones=pitch_shift))
143
- if delay: board.append(Delay(delay_seconds=delay_seconds, feedback=delay_feedback, mix=delay_mix))
144
- if compressor: board.append(Compressor(threshold_db=compressor_threshold, ratio=compressor_ratio, attack_ms=compressor_attack_ms, release_ms=compressor_release_ms))
145
- if limiter: board.append(Limiter(threshold_db=limiter_threshold, release_ms=limiter_release))
146
- if gain: board.append(Gain(gain_db=gain_db))
147
- if bitcrush: board.append(Bitcrush(bit_depth=bitcrush_bit_depth))
148
- if clipping: board.append(Clipping(threshold_db=clipping_threshold))
149
- if phaser: board.append(Phaser(rate_hz=phaser_rate_hz, depth=phaser_depth, centre_frequency_hz=phaser_centre_frequency_hz, feedback=phaser_feedback, mix=phaser_mix))
150
-
151
- processed_audio = board(audio, sample_rate)
152
-
153
- if treble_bass_boost:
154
- processed_audio = bass_boost(processed_audio, bass_boost_db, bass_boost_frequency, sample_rate)
155
- processed_audio = treble_boost(processed_audio, treble_boost_db, treble_boost_frequency, sample_rate)
156
-
157
- if fade_in_out:
158
- processed_audio = fade_in_effect(processed_audio, sample_rate, fade_in_duration)
159
- processed_audio = fade_out_effect(processed_audio, sample_rate, fade_out_duration)
160
-
161
- if resample_sr != sample_rate and resample_sr > 0 and resample:
162
- processed_audio = librosa.resample(processed_audio, orig_sr=sample_rate, target_sr=resample_sr)
163
- sample_rate = resample_sr
164
- except Exception as e:
165
- raise RuntimeError(translations["apply_error"].format(e=e))
166
-
167
- sf.write(output_path.replace(".wav", f".{export_format}"), processed_audio, sample_rate, format=export_format)
168
- return output_path
169
-
170
- if __name__ == "__main__": main()
 
main/inference/convert.py DELETED
@@ -1,1060 +0,0 @@
1
- import gc
2
- import re
3
- import os
4
- import sys
5
- import time
6
- import torch
7
- import faiss
8
- import shutil
9
- import codecs
10
- import pyworld
11
- import librosa
12
- import logging
13
- import argparse
14
- import warnings
15
- import traceback
16
- import torchcrepe
17
- import subprocess
18
- import parselmouth
19
- import logging.handlers
20
-
21
- import numpy as np
22
- import soundfile as sf
23
- import noisereduce as nr
24
- import torch.nn.functional as F
25
- import torch.multiprocessing as mp
26
-
27
- from tqdm import tqdm
28
- from scipy import signal
29
- from torch import Tensor
30
- from scipy.io import wavfile
31
- from audio_upscaler import upscale
32
- from distutils.util import strtobool
33
- from fairseq import checkpoint_utils
34
- from pydub import AudioSegment, silence
35
-
36
-
37
- now_dir = os.getcwd()
38
- sys.path.append(now_dir)
39
-
40
- from main.configs.config import Config
41
- from main.library.predictors.FCPE import FCPE
42
- from main.library.predictors.RMVPE import RMVPE
43
- from main.library.algorithm.synthesizers import Synthesizer
44
-
45
-
46
- warnings.filterwarnings("ignore", category=FutureWarning)
47
- warnings.filterwarnings("ignore", category=UserWarning)
48
-
49
- logging.getLogger("wget").setLevel(logging.ERROR)
50
- logging.getLogger("torch").setLevel(logging.ERROR)
51
- logging.getLogger("faiss").setLevel(logging.ERROR)
52
- logging.getLogger("httpx").setLevel(logging.ERROR)
53
- logging.getLogger("fairseq").setLevel(logging.ERROR)
54
- logging.getLogger("httpcore").setLevel(logging.ERROR)
55
- logging.getLogger("faiss.loader").setLevel(logging.ERROR)
56
-
57
-
58
- FILTER_ORDER = 5
59
- CUTOFF_FREQUENCY = 48
60
- SAMPLE_RATE = 16000
61
-
62
- bh, ah = signal.butter(N=FILTER_ORDER, Wn=CUTOFF_FREQUENCY, btype="high", fs=SAMPLE_RATE)
63
- input_audio_path2wav = {}
64
-
65
- log_file = os.path.join("assets", "logs", "convert.log")
66
-
67
- logger = logging.getLogger(__name__)
68
- logger.propagate = False
69
-
70
- translations = Config().translations
71
-
72
-
73
- if logger.hasHandlers(): logger.handlers.clear()
74
- else:
75
- console_handler = logging.StreamHandler()
76
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
77
-
78
- console_handler.setFormatter(console_formatter)
79
- console_handler.setLevel(logging.INFO)
80
-
81
- file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
82
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
83
-
84
- file_handler.setFormatter(file_formatter)
85
- file_handler.setLevel(logging.DEBUG)
86
-
87
- logger.addHandler(console_handler)
88
- logger.addHandler(file_handler)
89
- logger.setLevel(logging.DEBUG)
90
-
91
-
92
- def parse_arguments() -> tuple:
93
- parser = argparse.ArgumentParser()
94
- parser.add_argument("--pitch", type=int, default=0)
95
- parser.add_argument("--filter_radius", type=int, default=3)
96
- parser.add_argument("--index_rate", type=float, default=0.5)
97
- parser.add_argument("--volume_envelope", type=float, default=1)
98
- parser.add_argument("--protect", type=float, default=0.33)
99
- parser.add_argument("--hop_length", type=int, default=64)
100
- parser.add_argument( "--f0_method", type=str, default="rmvpe")
101
- parser.add_argument("--input_path", type=str, required=True)
102
- parser.add_argument("--output_path", type=str, default="./audios/output.wav")
103
- parser.add_argument("--pth_path", type=str, required=True)
104
- parser.add_argument("--index_path", type=str, required=True)
105
- parser.add_argument("--f0_autotune", type=lambda x: bool(strtobool(x)), default=False)
106
- parser.add_argument("--f0_autotune_strength", type=float, default=1)
107
- parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
108
- parser.add_argument("--clean_strength", type=float, default=0.7)
109
- parser.add_argument("--export_format", type=str, default="wav")
110
- parser.add_argument("--embedder_model", type=str, default="contentvec_base")
111
- parser.add_argument("--upscale_audio", type=lambda x: bool(strtobool(x)), default=False)
112
- parser.add_argument("--resample_sr", type=int, default=0)
113
- parser.add_argument("--batch_process", type=lambda x: bool(strtobool(x)), default=False)
114
- parser.add_argument("--batch_size", type=int, default=2)
115
- parser.add_argument("--split_audio", type=lambda x: bool(strtobool(x)), default=False)
116
-
117
- args = parser.parse_args()
118
- return args
119
-
120
-
121
- def main():
122
- args = parse_arguments()
123
- pitch = args.pitch
124
- filter_radius = args.filter_radius
125
- index_rate = args.index_rate
126
- volume_envelope = args.volume_envelope
127
- protect = args.protect
128
- hop_length = args.hop_length
129
- f0_method = args.f0_method
130
- input_path = args.input_path
131
- output_path = args.output_path
132
- pth_path = args.pth_path
133
- index_path = args.index_path
134
- f0_autotune = args.f0_autotune
135
- f0_autotune_strength = args.f0_autotune_strength
136
- clean_audio = args.clean_audio
137
- clean_strength = args.clean_strength
138
- export_format = args.export_format
139
- embedder_model = args.embedder_model
140
- upscale_audio = args.upscale_audio
141
- resample_sr = args.resample_sr
142
- batch_process = args.batch_process
143
- batch_size = args.batch_size
144
- split_audio = args.split_audio
145
-
146
- logger.debug(f"{translations['pitch']}: {pitch}")
147
- logger.debug(f"{translations['filter_radius']}: {filter_radius}")
148
- logger.debug(f"{translations['index_strength']} {index_rate}")
149
- logger.debug(f"{translations['volume_envelope']}: {volume_envelope}")
150
- logger.debug(f"{translations['protect']}: {protect}")
151
- if f0_method == "crepe" or f0_method == "crepe-tiny": logger.debug(f"Hop length: {hop_length}")
152
- logger.debug(f"{translations['f0_method']}: {f0_method}")
153
- logger.debug(f"f0_method: {input_path}")
154
- logger.debug(f"{translations['audio_path']}: {input_path}")
155
- logger.debug(f"{translations['output_path']}: {output_path.replace('.wav', f'.{export_format}')}")
156
- logger.debug(f"{translations['model_path']}: {pth_path}")
157
- logger.debug(f"{translations['indexpath']}: {index_path}")
158
- logger.debug(f"{translations['autotune']}: {f0_autotune}")
159
- logger.debug(f"{translations['clear_audio']}: {clean_audio}")
160
- if clean_audio: logger.debug(f"{translations['clean_strength']}: {clean_strength}")
161
- logger.debug(f"{translations['export_format']}: {export_format}")
162
- logger.debug(f"{translations['hubert_model']}: {embedder_model}")
163
- logger.debug(f"{translations['upscale_audio']}: {upscale_audio}")
164
- if resample_sr != 0: logger.debug(f"{translations['sample_rate']}: {resample_sr}")
165
- if split_audio: logger.debug(f"{translations['batch_process']}: {batch_process}")
166
- if batch_process and split_audio: logger.debug(f"{translations['batch_size']}: {batch_size}")
167
- logger.debug(f"{translations['split_audio']}: {split_audio}")
168
- if f0_autotune: logger.debug(f"{translations['autotune_rate_info']}: {f0_autotune_strength}")
169
-
170
-
171
- check_rmvpe_fcpe(f0_method)
172
- check_hubert(embedder_model)
173
-
174
- run_convert_script(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, input_path=input_path, output_path=output_path, pth_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, upscale_audio=upscale_audio, resample_sr=resample_sr, batch_process=batch_process, batch_size=batch_size, split_audio=split_audio)
175
-
176
-
177
- def check_rmvpe_fcpe(method):
178
- def download_rmvpe():
179
- if not os.path.exists(os.path.join("assets", "model", "predictors", "rmvpe.pt")): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + "rmvpe.pt", "-P", os.path.join("assets", "model", "predictors")], check=True)
180
-
181
- def download_fcpe():
182
- if not os.path.exists(os.path.join("assets", "model", "predictors", "fcpe.pt")): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + "fcpe.pt", "-P", os.path.join("assets", "model", "predictors")], check=True)
183
-
184
- if method == "rmvpe": download_rmvpe()
185
- elif method == "fcpe": download_fcpe()
186
- elif "hybrid" in method:
187
- methods_str = re.search("hybrid\[(.+)\]", method)
188
- if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]
189
-
190
- for method in methods:
191
- if method == "rmvpe": download_rmvpe()
192
- elif method == "fcpe": download_fcpe()
193
-
194
-
195
- def check_hubert(hubert):
196
- if hubert == "contentvec_base" or hubert == "hubert_base" or hubert == "japanese_hubert_base" or hubert == "korean_hubert_base" or hubert == "chinese_hubert_base":
197
- model_path = os.path.join(now_dir, "assets", "model", "embedders", hubert + '.pt')
198
-
199
- if not os.path.exists(model_path): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + f"{hubert}.pt", "-P", os.path.join("assets", "model", "embedders")], check=True)
200
-
201
-
202
- def load_audio_infer(file, sample_rate):
203
- try:
204
- file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
205
- if not os.path.isfile(file): raise FileNotFoundError(translations["not_found"].format(name=file))
206
-
207
- audio, sr = sf.read(file)
208
-
209
- if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
210
- if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate)
211
- except Exception as e:
212
- raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")
213
-
214
- return audio.flatten()
215
-
216
-
217
- def process_audio(file_path, output_path):
218
- try:
219
- song = AudioSegment.from_file(file_path)
220
- nonsilent_parts = silence.detect_nonsilent(song, min_silence_len=750, silence_thresh=-70)
221
-
222
- cut_files = []
223
- time_stamps = []
224
-
225
- min_chunk_duration = 30
226
-
227
- for i, (start_i, end_i) in enumerate(nonsilent_parts):
228
- chunk = song[start_i:end_i]
229
-
230
- if len(chunk) >= min_chunk_duration:
231
- chunk_file_path = os.path.join(output_path, f"chunk{i}.wav")
232
-
233
- if os.path.exists(chunk_file_path): os.remove(chunk_file_path)
234
- chunk.export(chunk_file_path, format="wav")
235
-
236
- cut_files.append(chunk_file_path)
237
- time_stamps.append((start_i, end_i))
238
- else: logger.debug(translations["skip_file"].format(i=i, chunk=len(chunk)))
239
-
240
- logger.info(f"{translations['split_total']}: {len(cut_files)}")
241
- return cut_files, time_stamps
242
- except Exception as e:
243
- raise RuntimeError(f"{translations['process_audio_error']}: {e}")
244
-
245
-
246
- def merge_audio(files_list, time_stamps, original_file_path, output_path, format):
247
- try:
248
- def extract_number(filename):
249
- match = re.search(r'_(\d+)', filename)
250
-
251
- return int(match.group(1)) if match else 0
252
-
253
- files_list = sorted(files_list, key=extract_number)
254
- total_duration = len(AudioSegment.from_file(original_file_path))
255
-
256
- combined = AudioSegment.empty()
257
- current_position = 0
258
-
259
- for file, (start_i, end_i) in zip(files_list, time_stamps):
260
- if start_i > current_position:
261
- silence_duration = start_i - current_position
262
- combined += AudioSegment.silent(duration=silence_duration)
263
-
264
- combined += AudioSegment.from_file(file)
265
- current_position = end_i
266
-
267
- if current_position < total_duration: combined += AudioSegment.silent(duration=total_duration - current_position)
268
-
269
- combined.export(output_path, format=format)
270
- return output_path
271
- except Exception as e:
272
- raise RuntimeError(f"{translations['merge_error']}: {e}")
273
-
274
-
275
- def run_batch_convert(params):
276
- cvt = VoiceConverter()
277
-
278
- path = params["path"]
279
- audio_temp = params["audio_temp"]
280
- export_format = params["export_format"]
281
- cut_files = params["cut_files"]
282
- pitch = params["pitch"]
283
- filter_radius = params["filter_radius"]
284
- index_rate = params["index_rate"]
285
- volume_envelope = params["volume_envelope"]
286
- protect = params["protect"]
287
- hop_length = params["hop_length"]
288
- f0_method = params["f0_method"]
289
- pth_path = params["pth_path"]
290
- index_path = params["index_path"]
291
- f0_autotune = params["f0_autotune"]
292
- f0_autotune_strength = params["f0_autotune_strength"]
293
- clean_audio = params["clean_audio"]
294
- clean_strength = params["clean_strength"]
295
- upscale_audio = params["upscale_audio"]
296
- embedder_model = params["embedder_model"]
297
- resample_sr = params["resample_sr"]
298
-
299
-
300
- segment_output_path = os.path.join(audio_temp, f"output_{cut_files.index(path)}.{export_format}")
301
- if os.path.exists(segment_output_path): os.remove(segment_output_path)
302
-
303
- cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=path, audio_output_path=segment_output_path, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, upscale_audio=upscale_audio, embedder_model=embedder_model, resample_sr=resample_sr)
304
- os.remove(path)
305
-
306
-
307
- if os.path.exists(segment_output_path): return segment_output_path
308
- else:
309
- logger.warning(f"{translations['not_found_convert_file']}: {segment_output_path}")
310
- sys.exit(1)
311
-
312
-
313
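When batch processing is enabled, `run_convert_script` below fans the per-chunk parameter dictionaries out over a process pool and collects results as they finish; a reduced sketch of that pattern (the worker and work items here are placeholders, not the real converter):

import multiprocessing as mp
from tqdm import tqdm

def worker(params):                                  # stand-in for run_batch_convert
    return params["path"]

if __name__ == "__main__":
    params_list = [{"path": f"chunk{i}.wav"} for i in range(8)]   # hypothetical work items
    results = []
    with mp.Pool(processes=min(4, len(params_list))) as pool:
        with tqdm(total=len(params_list)) as pbar:
            for result in pool.imap_unordered(worker, params_list):
                results.append(result)
                pbar.update(1)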
- def run_convert_script(pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0_method, input_path, output_path, pth_path, index_path, f0_autotune, f0_autotune_strength, clean_audio, clean_strength, export_format, upscale_audio, embedder_model, resample_sr, batch_process, batch_size, split_audio):
314
- cvt = VoiceConverter()
315
- start_time = time.time()
316
-
317
-
318
- if not pth_path or not os.path.exists(pth_path) or os.path.isdir(pth_path) or not pth_path.endswith(".pth"):
319
- logger.warning(translations["provide_file"].format(filename=translations["model"]))
320
- sys.exit(1)
321
-
322
- if not index_path or not os.path.exists(index_path) or os.path.isdir(index_path) or not index_path.endswith(".index"):
323
- logger.warning(translations["provide_file"].format(filename=translations["index"]))
324
- sys.exit(1)
325
-
326
-
327
- output_dir = os.path.dirname(output_path)
328
- output_dir = output_path if not output_dir else output_dir
329
-
330
- if output_dir is None: output_dir = "audios"
331
-
332
- if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True)
333
-
334
- audio_temp = "audios_temp"
335
- if not os.path.exists(audio_temp) and split_audio: os.makedirs(audio_temp, exist_ok=True)
336
-
337
- processed_segments = []
338
-
339
- if os.path.isdir(input_path):
340
- try:
341
- logger.info(translations["convert_batch"])
342
-
343
- audio_files = [f for f in os.listdir(input_path) if f.endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"))]
344
- if not audio_files:
345
- logger.warning(translations["not_found_audio"])
346
- sys.exit(1)
347
-
348
- logger.info(translations["found_audio"].format(audio_files=len(audio_files)))
349
-
350
- for audio in audio_files:
351
- audio_path = os.path.join(input_path, audio)
352
- output_audio = os.path.join(input_path, os.path.splitext(audio)[0] + f"_output.{export_format}")
353
-
354
- if split_audio:
355
- try:
356
- processed_segments = []  # start fresh for each input file so earlier segments don't leak into this merge
- cut_files, time_stamps = process_audio(audio_path, audio_temp)
357
- num_threads = min(batch_size, len(cut_files))
358
-
359
- params_list = [
360
- {
361
- "path": path,
362
- "audio_temp": audio_temp,
363
- "export_format": export_format,
364
- "cut_files": cut_files,
365
- "pitch": pitch,
366
- "filter_radius": filter_radius,
367
- "index_rate": index_rate,
368
- "volume_envelope": volume_envelope,
369
- "protect": protect,
370
- "hop_length": hop_length,
371
- "f0_method": f0_method,
372
- "pth_path": pth_path,
373
- "index_path": index_path,
374
- "f0_autotune": f0_autotune,
375
- "f0_autotune_strength": f0_autotune_strength,
376
- "clean_audio": clean_audio,
377
- "clean_strength": clean_strength,
378
- "upscale_audio": upscale_audio,
379
- "embedder_model": embedder_model,
380
- "resample_sr": resample_sr
381
- }
382
- for path in cut_files
383
- ]
384
-
385
- if batch_process:
386
- with mp.Pool(processes=num_threads) as pool:
387
- with tqdm(total=len(params_list), desc=translations["convert_audio"]) as pbar:
388
- for results in pool.imap_unordered(run_batch_convert, params_list):
389
- processed_segments.append(results)
390
- pbar.update(1)
391
- else:
392
- for params in tqdm(params_list, desc=translations["convert_audio"]):
393
- processed_segments.append(run_batch_convert(params))
394
-
395
- merge_audio(processed_segments, time_stamps, audio_path, output_audio, export_format)
396
- except Exception as e:
397
- logger.error(translations["error_convert_batch"].format(e=e))
398
- finally:
399
- if os.path.exists(audio_temp): shutil.rmtree(audio_temp, ignore_errors=True)
400
- else:
401
- try:
402
- logger.info(f"{translations['convert_audio']} '{audio_path}'...")
403
-
404
- if os.path.exists(output_audio): os.remove(output_audio)
405
-
406
- with tqdm(total=1, desc=translations["convert_audio"]) as pbar:
407
- cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=audio_path, audio_output_path=output_audio, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, upscale_audio=upscale_audio, embedder_model=embedder_model, resample_sr=resample_sr)
408
- pbar.update(1)
409
- except Exception as e:
410
- logger.error(translations["error_convert"].format(e=e))
411
-
412
- elapsed_time = time.time() - start_time
413
- logger.info(translations["convert_batch_success"].format(elapsed_time=f"{elapsed_time:.2f}", output_path=output_path.replace('.wav', f'.{export_format}')))
414
- except Exception as e:
415
- logger.error(translations["error_convert_batch_2"].format(e=e))
416
- else:
417
- logger.info(f"{translations['convert_audio']} '{input_path}'...")
418
-
419
- if not os.path.exists(input_path):
420
- logger.warning(translations["not_found_audio"])
421
- sys.exit(1)
422
-
423
- if os.path.isdir(output_path): output_path = os.path.join(output_path, f"output.{export_format}")
424
- if os.path.exists(output_path): os.remove(output_path)
425
-
426
- if split_audio:
427
- try:
428
- cut_files, time_stamps = process_audio(input_path, audio_temp)
429
- num_threads = min(batch_size, len(cut_files))
430
-
431
- params_list = [
432
- {
433
- "path": path,
434
- "audio_temp": audio_temp,
435
- "export_format": export_format,
436
- "cut_files": cut_files,
437
- "pitch": pitch,
438
- "filter_radius": filter_radius,
439
- "index_rate": index_rate,
440
- "volume_envelope": volume_envelope,
441
- "protect": protect,
442
- "hop_length": hop_length,
443
- "f0_method": f0_method,
444
- "pth_path": pth_path,
445
- "index_path": index_path,
446
- "f0_autotune": f0_autotune,
447
- "f0_autotune_strength": f0_autotune_strength,
448
- "clean_audio": clean_audio,
449
- "clean_strength": clean_strength,
450
- "upscale_audio": upscale_audio,
451
- "embedder_model": embedder_model,
452
- "resample_sr": resample_sr
453
- }
454
- for path in cut_files
455
- ]
456
-
457
- if batch_process:
458
- with mp.Pool(processes=num_threads) as pool:
459
- with tqdm(total=len(params_list), desc=translations["convert_audio"]) as pbar:
460
- for results in pool.imap_unordered(run_batch_convert, params_list):
461
- processed_segments.append(results)
462
- pbar.update(1)
463
- else:
464
- for params in tqdm(params_list, desc=translations["convert_audio"]):
465
- processed_segments.append(run_batch_convert(params))
466
-
467
- merge_audio(processed_segments, time_stamps, input_path, output_path.replace(".wav", f".{export_format}"), export_format)
468
- except Exception as e:
469
- logger.error(translations["error_convert_batch"].format(e=e))
470
- finally:
471
- if os.path.exists(audio_temp): shutil.rmtree(audio_temp, ignore_errors=True)
472
- else:
473
- try:
474
- with tqdm(total=1, desc=translations["convert_audio"]) as pbar:
475
- cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=input_path, audio_output_path=output_path, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, upscale_audio=upscale_audio, embedder_model=embedder_model, resample_sr=resample_sr)
476
- pbar.update(1)
477
- except Exception as e:
478
- logger.error(translations["error_convert"].format(e=e))
479
-
480
- elapsed_time = time.time() - start_time
481
- logger.info(translations["convert_audio_success"].format(input_path=input_path, elapsed_time=f"{elapsed_time:.2f}", output_path=output_path.replace('.wav', f'.{export_format}')))
482
-
483
-
484
- def change_rms(source_audio: np.ndarray, source_rate: int, target_audio: np.ndarray, target_rate: int, rate: float) -> np.ndarray:
485
- rms1 = librosa.feature.rms(
486
- y=source_audio,
487
- frame_length=source_rate // 2 * 2,
488
- hop_length=source_rate // 2,
489
- )
490
-
491
- rms2 = librosa.feature.rms(
492
- y=target_audio,
493
- frame_length=target_rate // 2 * 2,
494
- hop_length=target_rate // 2,
495
- )
496
-
497
- rms1 = F.interpolate(
498
- torch.from_numpy(rms1).float().unsqueeze(0),
499
- size=target_audio.shape[0],
500
- mode="linear",
501
- ).squeeze()
502
-
503
- rms2 = F.interpolate(
504
- torch.from_numpy(rms2).float().unsqueeze(0),
505
- size=target_audio.shape[0],
506
- mode="linear",
507
- ).squeeze()
508
-
509
- rms2 = torch.maximum(rms2, torch.zeros_like(rms2) + 1e-6)
510
-
511
-
512
- adjusted_audio = (target_audio * (torch.pow(rms1, 1 - rate) * torch.pow(rms2, rate - 1)).numpy())
513
- return adjusted_audio
514
-
515
-
516
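The per-frame gain applied by `change_rms` above is rms1^(1-rate) * rms2^(rate-1), where rms1 is the source envelope and rms2 the converted audio's envelope. A tiny numeric sketch with illustrative values:

rms_source, rms_target = 0.20, 0.05      # illustrative per-frame RMS values
for rate in (0.0, 0.5, 1.0):
    gain = (rms_source ** (1 - rate)) * (rms_target ** (rate - 1))
    print(rate, gain)
# rate=0.0 -> gain = rms_source / rms_target = 4.0 (source envelope fully imposed)
# rate=0.5 -> gain = 2.0 (halfway blend)
# rate=1.0 -> gain = 1.0 (converted audio keeps its own envelope)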
- class Autotune:
517
- def __init__(self, ref_freqs):
518
- self.ref_freqs = ref_freqs
519
- self.note_dict = self.ref_freqs
520
-
521
-
522
- def autotune_f0(self, f0, f0_autotune_strength):
523
- autotuned_f0 = np.zeros_like(f0)
524
-
525
-
526
- for i, freq in enumerate(f0):
527
- closest_note = min(self.note_dict, key=lambda x: abs(x - freq))
528
- autotuned_f0[i] = freq + (closest_note - freq) * f0_autotune_strength
529
-
530
- return autotuned_f0
531
-
532
-
533
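`Autotune.autotune_f0` pulls each frame part of the way toward the nearest reference frequency; a quick sketch with a toy note table (all values illustrative):

import numpy as np

note_table = [220.0, 233.08, 246.94, 261.63]   # toy reference frequencies in Hz
f0 = np.array([225.0, 250.0, 0.0])
strength = 0.8

snapped = np.zeros_like(f0)
for i, freq in enumerate(f0):
    nearest = min(note_table, key=lambda n: abs(n - freq))
    snapped[i] = freq + (nearest - freq) * strength   # strength=1.0 snaps fully onto the note
# 225.0 -> 221.0 and 250.0 -> 247.55; note that unvoiced frames (f0 = 0) are also
# pulled toward the lowest reference note, exactly as in the class above.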
- class VC:
534
- def __init__(self, tgt_sr, config):
535
- self.x_pad = config.x_pad
536
- self.x_query = config.x_query
537
- self.x_center = config.x_center
538
- self.x_max = config.x_max
539
- self.is_half = config.is_half
540
- self.sample_rate = 16000
541
- self.window = 160
542
- self.t_pad = self.sample_rate * self.x_pad
543
- self.t_pad_tgt = tgt_sr * self.x_pad
544
- self.t_pad2 = self.t_pad * 2
545
- self.t_query = self.sample_rate * self.x_query
546
- self.t_center = self.sample_rate * self.x_center
547
- self.t_max = self.sample_rate * self.x_max
548
- self.time_step = self.window / self.sample_rate * 1000
549
- self.f0_min = 50
550
- self.f0_max = 1100
551
- self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
552
- self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
553
- self.device = config.device
554
- self.ref_freqs = [
555
- 49.00,
556
- 51.91,
557
- 55.00,
558
- 58.27,
559
- 61.74,
560
- 65.41,
561
- 69.30,
562
- 73.42,
563
- 77.78,
564
- 82.41,
565
- 87.31,
566
- 92.50,
567
- 98.00,
568
- 103.83,
569
- 110.00,
570
- 116.54,
571
- 123.47,
572
- 130.81,
573
- 138.59,
574
- 146.83,
575
- 155.56,
576
- 164.81,
577
- 174.61,
578
- 185.00,
579
- 196.00,
580
- 207.65,
581
- 220.00,
582
- 233.08,
583
- 246.94,
584
- 261.63,
585
- 277.18,
586
- 293.66,
587
- 311.13,
588
- 329.63,
589
- 349.23,
590
- 369.99,
591
- 392.00,
592
- 415.30,
593
- 440.00,
594
- 466.16,
595
- 493.88,
596
- 523.25,
597
- 554.37,
598
- 587.33,
599
- 622.25,
600
- 659.25,
601
- 698.46,
602
- 739.99,
603
- 783.99,
604
- 830.61,
605
- 880.00,
606
- 932.33,
607
- 987.77,
608
- 1046.50
609
- ]
610
- self.autotune = Autotune(self.ref_freqs)
611
- self.note_dict = self.autotune.note_dict
612
-
613
-
614
- def get_f0_crepe(self, x, f0_min, f0_max, p_len, hop_length, model="full"):
615
- x = x.astype(np.float32)
616
- x /= np.quantile(np.abs(x), 0.999)
617
-
618
- audio = torch.from_numpy(x).to(self.device, copy=True)
619
- audio = torch.unsqueeze(audio, dim=0)
620
-
621
-
622
- if audio.ndim == 2 and audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True).detach()
623
-
624
- audio = audio.detach()
625
- pitch: Tensor = torchcrepe.predict(audio, self.sample_rate, hop_length, f0_min, f0_max, model, batch_size=hop_length * 2, device=self.device, pad=True)
626
-
627
- p_len = p_len or x.shape[0] // hop_length
628
- source = np.array(pitch.squeeze(0).cpu().float().numpy())
629
- source[source < 0.001] = np.nan
630
-
631
- target = np.interp(
632
- np.arange(0, len(source) * p_len, len(source)) / p_len,
633
- np.arange(0, len(source)),
634
- source,
635
- )
636
-
637
- f0 = np.nan_to_num(target)
638
- return f0
639
-
640
-
641
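`get_f0_crepe` finishes by stretching the frame-wise pitch track onto exactly `p_len` frames and zeroing unvoiced frames; the same interpolation in isolation (arrays are made up):

import numpy as np

source = np.array([100.0, 110.0, np.nan, 130.0])   # per-frame f0, NaN = unvoiced
p_len = 8                                          # desired number of frames

target = np.interp(
    np.arange(0, len(source) * p_len, len(source)) / p_len,  # p_len query points over the track
    np.arange(0, len(source)),
    source,
)
f0 = np.nan_to_num(target)                         # unvoiced frames back to 0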
- def get_f0_hybrid(self, methods_str, x, f0_min, f0_max, p_len, hop_length, filter_radius):
642
- methods_str = re.search(r"hybrid\[(.+)\]", methods_str)
643
- if not methods_str: raise ValueError(translations["method_not_valid"])
- methods = [method.strip() for method in methods_str.group(1).split("+")]
644
-
645
- f0_computation_stack = []
646
- logger.debug(translations["hybrid_methods"].format(methods=methods))
647
-
648
- x = x.astype(np.float32)
649
- x /= np.quantile(np.abs(x), 0.999)
650
-
651
-
652
- for method in methods:
653
- f0 = None
654
-
655
-
656
- if method == "pm":
657
- f0 = (parselmouth.Sound(x, self.sample_rate).to_pitch_ac(time_step=self.window / self.sample_rate * 1000 / 1000, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"])
658
- pad_size = (p_len - len(f0) + 1) // 2
659
-
660
- if pad_size > 0 or p_len - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
661
- elif method == 'dio':
662
- f0, t = pyworld.dio(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
663
- f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.sample_rate)
664
-
665
- f0 = signal.medfilt(f0, 3)
666
- elif method == "crepe-tiny":
667
- f0 = self.get_f0_crepe(x, self.f0_min, self.f0_max, p_len, int(hop_length), "tiny")
668
- elif method == "crepe":
669
- f0 = self.get_f0_crepe(x, f0_min, f0_max, p_len, int(hop_length))
670
- elif method == "fcpe":
671
- self.model_fcpe = FCPE(os.path.join("assets", "model", "predictors", "fcpe.pt"), hop_length=int(hop_length), f0_min=int(f0_min), f0_max=int(f0_max), dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.03)
672
- f0 = self.model_fcpe.compute_f0(x, p_len=p_len)
673
-
674
- del self.model_fcpe
675
- gc.collect()
676
- elif method == "rmvpe":
677
- f0 = RMVPE(os.path.join("assets", "model", "predictors", "rmvpe.pt"), is_half=self.is_half, device=self.device).infer_from_audio(x, thred=0.03)
678
- f0 = f0[1:]
679
- elif method == "harvest":
680
- f0, t = pyworld.harvest(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
681
- f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.sample_rate)
682
-
683
- if filter_radius > 2: f0 = signal.medfilt(f0, 3)
684
- else: raise ValueError(translations["method_not_valid"])
685
-
686
- f0_computation_stack.append(f0)
687
-
688
- resampled_stack = []
689
-
690
- for f0 in f0_computation_stack:
691
- resampled_f0 = np.interp(np.linspace(0, len(f0), p_len), np.arange(len(f0)), f0)
692
- resampled_stack.append(resampled_f0)
693
-
694
- f0_median_hybrid = resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)
695
- return f0_median_hybrid
696
-
697
-
698
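The hybrid mode above runs several estimators, resamples each track to a common length, and takes a per-frame median so a single tracker's octave error is suppressed; a compact sketch:

import numpy as np

p_len = 6
estimates = [
    np.array([100.0, 102.0, 101.0]),          # made-up tracker outputs
    np.array([100.0, 180.0, 101.0, 99.0]),    # one with an octave-type error
    np.array([101.0, 103.0, 100.0]),
]

resampled = [np.interp(np.linspace(0, len(f0), p_len), np.arange(len(f0)), f0) for f0 in estimates]
f0_hybrid = np.nanmedian(np.vstack(resampled), axis=0)   # per-frame median across estimators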
- def get_f0(self, input_audio_path, x, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength):
699
- global input_audio_path2wav
700
-
701
-
702
- if f0_method == "pm":
703
- f0 = (parselmouth.Sound(x, self.sample_rate).to_pitch_ac(time_step=self.window / self.sample_rate * 1000 / 1000, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"])
704
- pad_size = (p_len - len(f0) + 1) // 2
705
-
706
- if pad_size > 0 or p_len - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
707
- elif f0_method == "dio":
708
- f0, t = pyworld.dio(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
709
- f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.sample_rate)
710
-
711
- f0 = signal.medfilt(f0, 3)
712
- elif f0_method == "crepe-tiny":
713
- f0 = self.get_f0_crepe(x, self.f0_min, self.f0_max, p_len, int(hop_length), "tiny")
714
- elif f0_method == "crepe":
715
- f0 = self.get_f0_crepe(x, self.f0_min, self.f0_max, p_len, int(hop_length))
716
- elif f0_method == "fcpe":
717
- self.model_fcpe = FCPE(os.path.join("assets", "model", "predictors", "fcpe.pt"), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.03)
718
- f0 = self.model_fcpe.compute_f0(x, p_len=p_len)
719
-
720
- del self.model_fcpe
721
- gc.collect()
722
- elif f0_method == "rmvpe":
723
- f0 = RMVPE(os.path.join("assets", "model", "predictors", "rmvpe.pt"), is_half=self.is_half, device=self.device).infer_from_audio(x, thred=0.03)
724
- elif f0_method == "harvest":
725
- f0, t = pyworld.harvest(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
726
- f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.sample_rate)
727
-
728
- if filter_radius > 2: f0 = signal.medfilt(f0, 3)
729
- elif "hybrid" in f0_method:
730
- input_audio_path2wav[input_audio_path] = x.astype(np.double)
731
- f0 = self.get_f0_hybrid(f0_method, x, self.f0_min, self.f0_max, p_len, hop_length, filter_radius)
732
- else: raise ValueError(translations["method_not_valid"])
733
-
734
- if f0_autotune: f0 = self.autotune.autotune_f0(f0, f0_autotune_strength)
735
-
736
- f0 *= pow(2, pitch / 12)
737
-
738
- f0bak = f0.copy()
739
-
740
- f0_mel = 1127 * np.log(1 + f0 / 700)
741
- f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (self.f0_mel_max - self.f0_mel_min) + 1
742
- f0_mel[f0_mel <= 1] = 1
743
- f0_mel[f0_mel > 255] = 255
744
-
745
- f0_coarse = np.rint(f0_mel).astype(np.int32)
746
- return f0_coarse, f0bak
747
-
748
-
749
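`get_f0` returns the raw curve plus a coarse version quantized on a mel scale into bins 1..255 between f0_min = 50 Hz and f0_max = 1100 Hz; the mapping on its own (outputs are approximate):

import numpy as np

f0_min, f0_max = 50.0, 1100.0
f0_mel_min = 1127 * np.log(1 + f0_min / 700)
f0_mel_max = 1127 * np.log(1 + f0_max / 700)

f0 = np.array([0.0, 110.0, 440.0])                 # 0 = unvoiced frame
f0_mel = 1127 * np.log(1 + f0 / 700)
f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - f0_mel_min) * 254 / (f0_mel_max - f0_mel_min) + 1
f0_mel[f0_mel <= 1] = 1                            # unvoiced frames land in bin 1
f0_mel[f0_mel > 255] = 255
f0_coarse = np.rint(f0_mel).astype(np.int32)       # roughly [1, 23, 122]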
- def voice_conversion(self, model, net_g, sid, audio0, pitch, pitchf, index, big_npy, index_rate, version, protect):
750
- pitch_guidance = pitch is not None and pitchf is not None
751
-
752
- feats = (torch.from_numpy(audio0).half() if self.is_half else torch.from_numpy(audio0).float())
753
-
754
- if feats.dim() == 2: feats = feats.mean(-1)
755
- assert feats.dim() == 1, feats.dim()
756
-
757
- feats = feats.view(1, -1)
758
-
759
- padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)
760
-
761
- inputs = {
762
- "source": feats.to(self.device),
763
- "padding_mask": padding_mask,
764
- "output_layer": 9 if version == "v1" else 12,
765
- }
766
-
767
- with torch.no_grad():
768
- logits = model.extract_features(**inputs)
769
- feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
770
-
771
- if protect < 0.5 and pitch_guidance: feats0 = feats.clone()
772
-
773
- if (not isinstance(index, type(None)) and not isinstance(big_npy, type(None)) and index_rate != 0):
774
- npy = feats[0].cpu().numpy()
775
-
776
- if self.is_half: npy = npy.astype("float32")
777
-
778
- score, ix = index.search(npy, k=8)
779
-
780
- weight = np.square(1 / score)
781
- weight /= weight.sum(axis=1, keepdims=True)
782
-
783
- npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)
784
-
785
- if self.is_half: npy = npy.astype("float16")
786
-
787
- feats = (torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats)
788
-
789
- feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
790
-
791
- if protect < 0.5 and pitch_guidance: feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)
792
-
793
- p_len = audio0.shape[0] // self.window
794
-
795
- if feats.shape[1] < p_len:
796
- p_len = feats.shape[1]
797
-
798
- if pitch_guidance:
799
- pitch = pitch[:, :p_len]
800
- pitchf = pitchf[:, :p_len]
801
-
802
- if protect < 0.5 and pitch_guidance:
803
- pitchff = pitchf.clone()
804
- pitchff[pitchf > 0] = 1
805
- pitchff[pitchf < 1] = protect
806
- pitchff = pitchff.unsqueeze(-1)
807
-
808
- feats = feats * pitchff + feats0 * (1 - pitchff)
809
- feats = feats.to(feats0.dtype)
810
-
811
- p_len = torch.tensor([p_len], device=self.device).long()
812
-
813
- with torch.no_grad():
814
- audio1 = ((net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0]).data.cpu().float().numpy()) if pitch_guidance else ((net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy())
815
-
816
- del feats, p_len, padding_mask
817
-
818
- if torch.cuda.is_available(): torch.cuda.empty_cache()
819
- return audio1
820
-
821
-
822
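`voice_conversion` above optionally looks up the 8 nearest training features in the FAISS index and blends them into the HuBERT features with `index_rate`; a reduced sketch of that retrieval step (the index file and feature shapes are assumptions):

import faiss
import numpy as np

index = faiss.read_index("model.index")              # hypothetical .index file
big_npy = index.reconstruct_n(0, index.ntotal)       # all stored feature vectors

feats = np.random.rand(200, 768).astype("float32")   # stand-in for v2 HuBERT features
index_rate = 0.75

score, ix = index.search(feats, k=8)                 # distances and ids of the 8 neighbours
weight = np.square(1 / score)                        # inverse-square-distance weighting
weight /= weight.sum(axis=1, keepdims=True)
retrieved = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)

blended = retrieved * index_rate + feats * (1 - index_rate)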
- def pipeline(self, model, net_g, sid, audio, input_audio_path, pitch, f0_method, file_index, index_rate, pitch_guidance, filter_radius, tgt_sr, resample_sr, volume_envelope, version, protect, hop_length, f0_autotune, f0_autotune_strength):
823
- if file_index != "" and os.path.exists(file_index) and index_rate != 0:
824
- try:
825
- index = faiss.read_index(file_index)
826
- big_npy = index.reconstruct_n(0, index.ntotal)
827
- except Exception as e:
828
- logger.error(translations["read_faiss_index_error"].format(e=e))
829
- index = big_npy = None
830
- else: index = big_npy = None
831
-
832
- audio = signal.filtfilt(bh, ah, audio)
833
- audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
834
- opt_ts = []
835
-
836
- if audio_pad.shape[0] > self.t_max:
837
- audio_sum = np.zeros_like(audio)
838
-
839
- for i in range(self.window):
840
- audio_sum += audio_pad[i : i - self.window]
841
-
842
- for t in range(self.t_center, audio.shape[0], self.t_center):
843
- opt_ts.append(t - self.t_query + np.where(np.abs(audio_sum[t - self.t_query : t + self.t_query]) == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min())[0][0])
844
-
845
- s = 0
846
- audio_opt = []
847
- t = None
848
-
849
- audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
850
- p_len = audio_pad.shape[0] // self.window
851
-
852
- sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()
853
-
854
- if pitch_guidance:
855
- pitch, pitchf = self.get_f0(input_audio_path, audio_pad, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength)
856
- pitch = pitch[:p_len]
857
- pitchf = pitchf[:p_len]
858
-
859
- if self.device == "mps": pitchf = pitchf.astype(np.float32)
860
-
861
- pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
862
- pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()
863
-
864
- for t in opt_ts:
865
- t = t // self.window * self.window
866
-
867
- if pitch_guidance: audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[s : t + self.t_pad2 + self.window], pitch[:, s // self.window : (t + self.t_pad2) // self.window], pitchf[:, s // self.window : (t + self.t_pad2) // self.window], index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
868
- else: audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[s : t + self.t_pad2 + self.window], None, None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
869
-
870
- s = t
871
-
872
- if pitch_guidance: audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[t:], pitch[:, t // self.window :] if t is not None else pitch, pitchf[:, t // self.window :] if t is not None else pitchf, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
873
- else: audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[t:], None, None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
874
-
875
- audio_opt = np.concatenate(audio_opt)
876
-
877
- if volume_envelope != 1: audio_opt = change_rms(audio, self.sample_rate, audio_opt, tgt_sr, volume_envelope)
878
- if resample_sr >= self.sample_rate and tgt_sr != resample_sr: audio_opt = librosa.resample(audio_opt, orig_sr=tgt_sr, target_sr=resample_sr)
879
-
880
- audio_max = np.abs(audio_opt).max() / 0.99
881
- max_int16 = 32768
882
-
883
- if audio_max > 1: max_int16 /= audio_max
884
-
885
- audio_opt = (audio_opt * max_int16).astype(np.int16)
886
-
887
- if pitch_guidance: del pitch, pitchf
888
- del sid
889
-
890
- if torch.cuda.is_available(): torch.cuda.empty_cache()
891
- return audio_opt
892
-
893
-
894
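The tail of `pipeline` converts the float output to 16-bit PCM, shrinking the scale only when the signal would clip; the same step in isolation (the signal here is synthetic):

import numpy as np

audio_opt = np.random.uniform(-1.2, 1.2, 16000).astype(np.float32)  # synthetic float signal

audio_max = np.abs(audio_opt).max() / 0.99
max_int16 = 32768
if audio_max > 1:
    max_int16 /= audio_max                     # keep ~1% headroom instead of hard clipping
audio_int16 = (audio_opt * max_int16).astype(np.int16)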
- class VoiceConverter:
895
- def __init__(self):
896
- self.config = Config()
897
- self.hubert_model = None
898
-
899
- self.tgt_sr = None
900
- self.net_g = None
901
-
902
- self.vc = None
903
- self.cpt = None
904
-
905
- self.version = None
906
- self.n_spk = None
907
-
908
- self.use_f0 = None
909
- self.loaded_model = None
910
-
911
-
912
- def load_hubert(self, embedder_model):
913
- try:
914
- models, _, _ = checkpoint_utils.load_model_ensemble_and_task([os.path.join(now_dir, "assets", "model", "embedders", embedder_model + '.pt')], suffix="")
915
- except Exception as e:
916
- raise ImportError(translations["read_model_error"].format(e=e))
917
-
918
- self.hubert_model = models[0].to(self.config.device)
919
- self.hubert_model = (self.hubert_model.half() if self.config.is_half else self.hubert_model.float())
920
- self.hubert_model.eval()
921
-
922
-
923
- @staticmethod
924
- def remove_audio_noise(input_audio_path, reduction_strength=0.7):
925
- try:
926
- rate, data = wavfile.read(input_audio_path)
927
- reduced_noise = nr.reduce_noise(y=data, sr=rate, prop_decrease=reduction_strength)
928
-
929
- return reduced_noise
930
- except Exception as e:
931
- logger.error(translations["denoise_error"].format(e=e))
932
- return None
933
-
934
-
935
- @staticmethod
936
- def convert_audio_format(input_path, output_path, output_format):
937
- try:
938
- if output_format != "wav":
939
- logger.debug(translations["change_format"].format(output_format=output_format))
940
- audio, sample_rate = sf.read(input_path)
941
-
942
-
943
- common_sample_rates = [
944
- 8000,
945
- 11025,
946
- 12000,
947
- 16000,
948
- 22050,
949
- 24000,
950
- 32000,
951
- 44100,
952
- 48000
953
- ]
954
-
955
- target_sr = min(common_sample_rates, key=lambda x: abs(x - sample_rate))
956
- audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=target_sr)
957
-
958
- sf.write(output_path, audio, target_sr, format=output_format)
959
-
960
- return output_path
961
- except Exception as e:
962
- raise RuntimeError(translations["change_format_error"].format(e=e))
963
-
964
-
965
- def convert_audio(self, audio_input_path, audio_output_path, model_path, index_path, embedder_model, pitch, f0_method, index_rate, volume_envelope, protect, hop_length, f0_autotune, f0_autotune_strength, filter_radius, clean_audio, clean_strength, export_format, upscale_audio, resample_sr = 0, sid = 0):
966
- self.get_vc(model_path, sid)
967
-
968
- try:
969
- if upscale_audio: upscale(audio_input_path, audio_input_path)
970
-
971
- audio = load_audio_infer(audio_input_path, 16000)
972
-
973
- audio_max = np.abs(audio).max() / 0.95
974
-
975
-
976
- if audio_max > 1: audio /= audio_max
977
-
978
- if not self.hubert_model:
979
- if not os.path.exists(os.path.join(now_dir, "assets", "model", "embedders", embedder_model + '.pt')): raise FileNotFoundError(f"Embedder model not found: {embedder_model}")
980
-
981
- self.load_hubert(embedder_model)
982
-
983
- if resample_sr >= 16000 and self.tgt_sr != resample_sr: self.tgt_sr = resample_sr
984
-
985
- file_index = (index_path.strip().strip('"').strip("\n").strip('"').strip().replace("trained", "added"))
986
-
987
- audio_opt = self.vc.pipeline(model=self.hubert_model, net_g=self.net_g, sid=sid, audio=audio, input_audio_path=audio_input_path, pitch=pitch, f0_method=f0_method, file_index=file_index, index_rate=index_rate, pitch_guidance=self.use_f0, filter_radius=filter_radius, tgt_sr=self.tgt_sr, resample_sr=resample_sr, volume_envelope=volume_envelope, version=self.version, protect=protect, hop_length=hop_length, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength)
988
-
989
- if audio_output_path: sf.write(audio_output_path, audio_opt, self.tgt_sr, format="wav")
990
-
991
- if clean_audio:
992
- cleaned_audio = self.remove_audio_noise(audio_output_path, clean_strength)
993
- if cleaned_audio is not None: sf.write(audio_output_path, cleaned_audio, self.tgt_sr, format="wav")
994
-
995
- output_path_format = audio_output_path.replace(".wav", f".{export_format}")
996
- audio_output_path = self.convert_audio_format(audio_output_path, output_path_format, export_format)
997
- except Exception as e:
998
- logger.error(translations["error_convert"].format(e=e))
999
- logger.error(traceback.format_exc())
1000
-
1001
-
1002
- def get_vc(self, weight_root, sid):
1003
- if sid == "" or sid == []:
1004
- self.cleanup_model()
1005
- if torch.cuda.is_available(): torch.cuda.empty_cache()
1006
-
1007
- if not self.loaded_model or self.loaded_model != weight_root:
1008
- self.load_model(weight_root)
1009
-
1010
- if self.cpt is not None:
1011
- self.setup_network()
1012
- self.setup_vc_instance()
1013
-
1014
- self.loaded_model = weight_root
1015
-
1016
-
1017
- def cleanup_model(self):
1018
- if self.hubert_model is not None:
1019
- del self.net_g, self.n_spk, self.vc, self.hubert_model, self.tgt_sr
1020
-
1021
- self.hubert_model = self.net_g = self.n_spk = self.vc = self.tgt_sr = None
1022
-
1023
- if torch.cuda.is_available(): torch.cuda.empty_cache()
1024
-
1025
- del self.net_g, self.cpt
1026
-
1027
- if torch.cuda.is_available(): torch.cuda.empty_cache()
1028
- self.cpt = None
1029
-
1030
-
1031
- def load_model(self, weight_root):
1032
- self.cpt = (torch.load(weight_root, map_location="cpu") if os.path.isfile(weight_root) else None)
1033
-
1034
-
1035
- def setup_network(self):
1036
- if self.cpt is not None:
1037
- self.tgt_sr = self.cpt["config"][-1]
1038
- self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0]
1039
- self.use_f0 = self.cpt.get("f0", 1)
1040
-
1041
- self.version = self.cpt.get("version", "v1")
1042
- self.text_enc_hidden_dim = 768 if self.version == "v2" else 256
1043
-
1044
- self.net_g = Synthesizer(*self.cpt["config"], use_f0=self.use_f0, text_enc_hidden_dim=self.text_enc_hidden_dim, is_half=self.config.is_half)
1045
-
1046
- del self.net_g.enc_q
1047
-
1048
- self.net_g.load_state_dict(self.cpt["weight"], strict=False)
1049
- self.net_g.eval().to(self.config.device)
1050
- self.net_g = (self.net_g.half() if self.config.is_half else self.net_g.float())
1051
-
1052
-
1053
- def setup_vc_instance(self):
1054
- if self.cpt is not None:
1055
- self.vc = VC(self.tgt_sr, self.config)
1056
- self.n_spk = self.cpt["config"][-3]
1057
-
1058
- if __name__ == "__main__":
1059
- mp.set_start_method("spawn", force=True)
1060
- main()
main/inference/create_dataset.py DELETED
@@ -1,370 +0,0 @@
1
- import os
2
- import re
3
- import sys
4
- import time
5
- import yt_dlp
6
- import shutil
7
- import librosa
8
- import logging
9
- import argparse
10
- import warnings
11
- import logging.handlers
12
-
13
- import soundfile as sf
14
- import noisereduce as nr
15
-
16
- from distutils.util import strtobool
17
- from pydub import AudioSegment, silence
18
-
19
-
20
- now_dir = os.getcwd()
21
- sys.path.append(now_dir)
22
-
23
- from main.configs.config import Config
24
- from main.library.algorithm.separator import Separator
25
-
26
-
27
- translations = Config().translations
28
-
29
-
30
- log_file = os.path.join("assets", "logs", "create_dataset.log")
31
- logger = logging.getLogger(__name__)
32
-
33
- if logger.hasHandlers(): logger.handlers.clear()
34
- else:
35
- console_handler = logging.StreamHandler()
36
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
37
-
38
- console_handler.setFormatter(console_formatter)
39
- console_handler.setLevel(logging.INFO)
40
-
41
- file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
42
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
43
-
44
- file_handler.setFormatter(file_formatter)
45
- file_handler.setLevel(logging.DEBUG)
46
-
47
- logger.addHandler(console_handler)
48
- logger.addHandler(file_handler)
49
- logger.setLevel(logging.DEBUG)
50
-
51
-
52
- def parse_arguments() -> tuple:
53
- parser = argparse.ArgumentParser()
54
- parser.add_argument("--input_audio", type=str, required=True)
55
- parser.add_argument("--output_dataset", type=str, default="./dataset")
56
- parser.add_argument("--resample", type=lambda x: bool(strtobool(x)), default=False)
57
- parser.add_argument("--resample_sr", type=int, default=44100)
58
- parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
59
- parser.add_argument("--clean_strength", type=float, default=0.7)
60
- parser.add_argument("--separator_music", type=lambda x: bool(strtobool(x)), default=False)
61
- parser.add_argument("--separator_reverb", type=lambda x: bool(strtobool(x)), default=False)
62
- parser.add_argument("--kim_vocal_version", type=int, default=2)
63
- parser.add_argument("--overlap", type=float, default=0.25)
64
- parser.add_argument("--segments_size", type=int, default=256)
65
- parser.add_argument("--mdx_hop_length", type=int, default=1024)
66
- parser.add_argument("--mdx_batch_size", type=int, default=1)
67
- parser.add_argument("--denoise_mdx", type=lambda x: bool(strtobool(x)), default=False)
68
- parser.add_argument("--skip", type=lambda x: bool(strtobool(x)), default=False)
69
- parser.add_argument("--skip_start_audios", type=str, default="0")
70
- parser.add_argument("--skip_end_audios", type=str, default="0")
71
-
72
- args = parser.parse_args()
73
- return args
74
-
75
-
76
- dataset_temp = os.path.join("dataset_temp")
77
-
78
-
79
- def main():
80
- args = parse_arguments()
81
- input_audio = args.input_audio
82
- output_dataset = args.output_dataset
83
- resample = args.resample
84
- resample_sr = args.resample_sr
85
- clean_dataset = args.clean_dataset
86
- clean_strength = args.clean_strength
87
- separator_music = args.separator_music
88
- separator_reverb = args.separator_reverb
89
- kim_vocal_version = args.kim_vocal_version
90
- overlap = args.overlap
91
- segments_size = args.segments_size
92
- hop_length = args.mdx_hop_length
93
- batch_size = args.mdx_batch_size
94
- denoise_mdx = args.denoise_mdx
95
- skip = args.skip
96
- skip_start_audios = args.skip_start_audios
97
- skip_end_audios = args.skip_end_audios
98
-
99
- logger.debug(f"{translations['audio_path']}: {input_audio}")
100
- logger.debug(f"{translations['output_path']}: {output_dataset}")
101
- logger.debug(f"{translations['resample']}: {resample}")
102
- if resample: logger.debug(f"{translations['sample_rate']}: {resample_sr}")
103
- logger.debug(f"{translations['clear_dataset']}: {clean_dataset}")
104
- if clean_dataset: logger.debug(f"{translations['clean_strength']}: {clean_strength}")
105
- logger.debug(f"{translations['separator_music']}: {separator_music}")
106
- logger.debug(f"{translations['dereveb_audio']}: {separator_reverb}")
107
- if separator_music: logger.debug(f"{translations['training_version']}: {kim_vocal_version}")
108
- logger.debug(f"{translations['segments_size']}: {segments_size}")
109
- logger.debug(f"{translations['overlap']}: {overlap}")
110
- logger.debug(f"Hop length: {hop_length}")
111
- logger.debug(f"{translations['batch_size']}: {batch_size}")
112
- logger.debug(f"{translations['denoise_mdx']}: {denoise_mdx}")
113
- logger.debug(f"{translations['skip']}: {skip}")
114
- if skip: logger.debug(f"{translations['skip_start']}: {skip_start_audios}")
115
- if skip: logger.debug(f"{translations['skip_end']}: {skip_end_audios}")
116
-
117
-
118
- if kim_vocal_version != 1 and kim_vocal_version != 2: raise ValueError(translations["version_not_valid"])
119
- if separator_reverb and not separator_music: raise ValueError(translations["create_dataset_value_not_valid"])
120
-
121
- start_time = time.time()
122
-
123
-
124
- try:
125
- paths = []
126
-
127
- if not os.path.exists(dataset_temp): os.makedirs(dataset_temp, exist_ok=True)
- if not os.path.exists(output_dataset): os.makedirs(output_dataset, exist_ok=True)  # the final move in the finally-block needs this folder
128
-
129
- urls = input_audio.replace(", ", ",").split(",")
130
-
131
- for i, url in enumerate(urls):
132
- path = downloader(url, i)
133
- paths.append(path)
134
-
135
- if skip:
136
- skip_start_audios = skip_start_audios.replace(", ", ",").split(",")
137
- skip_end_audios = skip_end_audios.replace(", ", ",").split(",")
138
-
139
- if len(skip_start_audios) < len(paths) or len(skip_end_audios) < len(paths):
140
- logger.warning(translations["skip<audio"])
141
- sys.exit(1)
142
- elif len(skip_start_audios) > len(paths) or len(skip_end_audios) > len(paths):
143
- logger.warning(translations["skip>audio"])
144
- sys.exit(1)
145
- else:
146
- for audio, skip_start_audio, skip_end_audio in zip(paths, skip_start_audios, skip_end_audios):
147
- skip_start(audio, skip_start_audio)
148
- skip_end(audio, skip_end_audio)
149
-
150
- if separator_music:
151
- separator_paths = []
152
-
153
- for audio in paths:
154
- vocals = separator_music_main(audio, dataset_temp, segments_size, overlap, denoise_mdx, kim_vocal_version, hop_length, batch_size)
155
-
156
- if separator_reverb: vocals = separator_reverb_audio(vocals, dataset_temp, segments_size, overlap, denoise_mdx, hop_length, batch_size)
157
- separator_paths.append(vocals)
158
-
159
- paths = separator_paths
160
-
161
- processed_paths = []
162
-
163
- for audio in paths:
164
- output = process_audio(audio)
165
- processed_paths.append(output)
166
-
167
- paths = processed_paths
168
-
169
- for audio_path in paths:
170
- data, sample_rate = sf.read(audio_path)
171
-
172
- if resample_sr != sample_rate and resample_sr > 0 and resample:
173
- data = librosa.resample(data, orig_sr=sample_rate, target_sr=resample_sr)
174
- sample_rate = resample_sr
175
-
176
- if clean_dataset: data = nr.reduce_noise(y=data, prop_decrease=clean_strength)
177
-
178
-
179
- sf.write(audio_path, data, sample_rate)
180
- except Exception as e:
181
- raise RuntimeError(f"{translations['create_dataset_error']}: {e}")
182
- finally:
183
- for audio in paths:
184
- shutil.move(audio, output_dataset)
185
-
186
- if os.path.exists(dataset_temp): shutil.rmtree(dataset_temp, ignore_errors=True)
187
-
188
-
189
- elapsed_time = time.time() - start_time
190
- logger.info(translations["create_dataset_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
191
-
192
-
193
- def downloader(url, name):
194
- with warnings.catch_warnings():
195
- warnings.simplefilter("ignore")
196
-
197
- ydl_opts = {
198
- 'format': 'bestaudio/best',
199
- 'outtmpl': os.path.join(dataset_temp, f"{name}"),
200
- 'postprocessors': [{
201
- 'key': 'FFmpegExtractAudio',
202
- 'preferredcodec': 'wav',
203
- 'preferredquality': '192',
204
- }],
205
- 'noplaylist': True,
206
- 'verbose': False,
207
- }
208
-
209
- logger.info(f"{translations['starting_download']}: {url}...")
210
- with yt_dlp.YoutubeDL(ydl_opts) as ydl:
211
- ydl.extract_info(url)
212
- logger.info(f"{translations['download_success']}: {url}")
213
-
214
- return os.path.join(dataset_temp, f"{name}" + ".wav")
215
-
216
-
217
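The `ydl_opts` above request the best audio stream and have ffmpeg transcode it to wav; a stand-alone sketch of the same call (the URL and output name are placeholders, and ffmpeg must be available on PATH):

import yt_dlp

ydl_opts = {
    "format": "bestaudio/best",
    "outtmpl": "dataset_temp/0",                    # yt_dlp appends the extension
    "postprocessors": [{
        "key": "FFmpegExtractAudio",
        "preferredcodec": "wav",
        "preferredquality": "192",
    }],
    "noplaylist": True,
}

with yt_dlp.YoutubeDL(ydl_opts) as ydl:
    ydl.extract_info("https://www.youtube.com/watch?v=XXXXXXXXXXX")  # placeholder URL
# result: dataset_temp/0.wav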
- def skip_start(input_file, seconds):
218
- seconds = float(seconds)  # skip values arrive as comma-separated strings from the CLI
- data, sr = sf.read(input_file)
219
-
220
- total_duration = len(data) / sr
221
-
222
- if seconds <= 0: logger.warning(translations["=<0"])
223
- elif seconds >= total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
224
- else:
225
- logger.info(f"{translations['skip_start']}: {input_file}...")
226
-
227
- sf.write(input_file, data[int(seconds * sr):], sr)
228
-
229
- logger.info(translations["skip_start_audio"].format(input_file=input_file))
230
-
231
-
232
- def skip_end(input_file, seconds):
233
- seconds = float(seconds)  # skip values arrive as comma-separated strings from the CLI
- data, sr = sf.read(input_file)
234
-
235
- total_duration = len(data) / sr
236
-
237
- if seconds <= 0: logger.warning(translations["=<0"])
238
- elif seconds > total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
239
- else:
240
- logger.info(f"{translations['skip_end']}: {input_file}...")
241
-
242
- sf.write(input_file, data[:-int(seconds * sr)], sr)
243
-
244
- logger.info(translations["skip_end_audio"].format(input_file=input_file))
245
-
246
-
247
- def process_audio(file_path):
248
- try:
249
- song = AudioSegment.from_file(file_path)
250
- nonsilent_parts = silence.detect_nonsilent(song, min_silence_len=750, silence_thresh=-70)
251
-
252
- cut_files = []
253
-
254
- for i, (start_i, end_i) in enumerate(nonsilent_parts):
255
- chunk = song[start_i:end_i]
256
-
257
- if len(chunk) >= 30:
258
- chunk_file_path = os.path.join(os.path.dirname(file_path), f"chunk{i}.wav")
259
- if os.path.exists(chunk_file_path): os.remove(chunk_file_path)
260
-
261
- chunk.export(chunk_file_path, format="wav")
262
-
263
- cut_files.append(chunk_file_path)
264
- else: logger.warning(translations["skip_file"].format(i=i, chunk=len(chunk)))
265
-
266
- logger.info(f"{translations['split_total']}: {len(cut_files)}")
267
-
268
- def extract_number(filename):
269
- match = re.search(r'_(\d+)', filename)
270
-
271
- return int(match.group(1)) if match else 0
272
-
273
- cut_files = sorted(cut_files, key=extract_number)
274
-
275
- combined = AudioSegment.empty()
276
-
277
- for file in cut_files:
278
- combined += AudioSegment.from_file(file)
279
-
280
- output_path = os.path.splitext(file_path)[0] + "_processed" + ".wav"
281
-
282
- logger.info(translations["merge_audio"])
283
-
284
- combined.export(output_path, format="wav")
285
-
286
- return output_path
287
- except Exception as e:
288
- raise RuntimeError(f"{translations['process_audio_error']}: {e}")
289
-
290
-
291
- def separator_music_main(input, output, segments_size, overlap, denoise, version, hop_length, batch_size):
292
- if not os.path.exists(input):
293
- logger.warning(translations["input_not_valid"])
294
- return None
295
-
296
- if not os.path.exists(output):
297
- logger.warning(translations["output_not_valid"])
298
- return None
299
-
300
- model = f"Kim_Vocal_{version}.onnx"
301
-
302
- logger.info(translations["separator_process"].format(input=input))
303
- output_separator = separator_main(audio_file=input, model_filename=model, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise)
304
-
305
- for f in output_separator:
306
- path = os.path.join(output, f)
307
-
308
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
309
-
310
- if '_(Instrumental)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
311
- elif '_(Vocals)_' in f:
312
- rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
313
- os.rename(path, rename_file)
314
-
315
- logger.info(f": {rename_file}")
316
- return rename_file
317
-
318
-
319
- def separator_reverb_audio(input, output, segments_size, overlap, denoise, hop_length, batch_size):
320
- reverb_models = "Reverb_HQ_By_FoxJoy.onnx"
321
-
322
- if not os.path.exists(input):
323
- logger.warning(translations["input_not_valid"])
324
- return None
325
-
326
- if not os.path.exists(output):
327
- logger.warning(translations["output_not_valid"])
328
- return None
329
-
330
- logger.info(f"{translations['dereverb']}: {input}...")
331
- output_dereverb = separator_main(audio_file=input, model_filename=reverb_models, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise)
332
-
333
- for f in output_dereverb:
334
- path = os.path.join(output, f)
335
-
336
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
337
-
338
- if '_(Reverb)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
339
- elif '_(No Reverb)_' in f:
340
- rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
341
- os.rename(path, rename_file)
342
-
343
- logger.info(f"{translations['dereverb_success']}: {rename_file}")
344
- return rename_file
345
-
346
-
347
- def separator_main(audio_file=None, model_filename="Kim_Vocal_1.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True):
348
- separator = Separator(
349
- log_formatter=file_formatter,
350
- log_level=logging.INFO,
351
- output_dir=output_dir,
352
- output_format=output_format,
353
- output_bitrate=None,
354
- normalization_threshold=0.9,
355
- output_single_stem=None,
356
- invert_using_spec=False,
357
- sample_rate=44100,
358
- mdx_params={
359
- "hop_length": mdx_hop_length,
360
- "segment_size": mdx_segment_size,
361
- "overlap": mdx_overlap,
362
- "batch_size": mdx_batch_size,
363
- "enable_denoise": mdx_enable_denoise,
364
- },
365
- )
366
-
367
- separator.load_model(model_filename=model_filename)
368
- return separator.separate(audio_file)
369
-
370
- if __name__ == "__main__": main()
main/inference/create_index.py DELETED
@@ -1,120 +0,0 @@
1
- import os
2
- import sys
3
- import faiss
4
- import logging
5
- import argparse
6
- import logging.handlers
7
-
8
- import numpy as np
9
-
10
- from multiprocessing import cpu_count
11
- from sklearn.cluster import MiniBatchKMeans
12
-
13
-
14
- now_dir = os.getcwd()
15
- sys.path.append(now_dir)
16
-
17
- from main.configs.config import Config
18
-
19
- translations = Config().translations
20
-
21
-
22
- def parse_arguments() -> tuple:
23
- parser = argparse.ArgumentParser()
24
- parser.add_argument("--model_name", type=str, required=True)
25
- parser.add_argument("--rvc_version", type=str, default="v2")
26
- parser.add_argument("--index_algorithm", type=str, default="Auto")
27
-
28
- args = parser.parse_args()
29
- return args
30
-
31
-
32
- def main():
33
- args = parse_arguments()
34
-
35
- exp_dir = os.path.join("assets", "logs", args.model_name)
36
- version = args.rvc_version
37
- index_algorithm = args.index_algorithm
38
-
39
- log_file = os.path.join(exp_dir, "create_index.log")
40
- logger = logging.getLogger(__name__)
41
-
42
-
43
- if logger.hasHandlers(): logger.handlers.clear()
44
- else:
45
- console_handler = logging.StreamHandler()
46
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
47
-
48
- console_handler.setFormatter(console_formatter)
49
- console_handler.setLevel(logging.INFO)
50
-
51
- file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
52
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
53
-
54
- file_handler.setFormatter(file_formatter)
55
- file_handler.setLevel(logging.DEBUG)
56
-
57
- logger.addHandler(console_handler)
58
- logger.addHandler(file_handler)
59
- logger.setLevel(logging.DEBUG)
60
-
61
- logger.debug(f"{translations['modelname']}: {args.model_name}")
62
- logger.debug(f"{translations['model_path']}: {exp_dir}")
63
- logger.debug(f"{translations['training_version']}: {version}")
64
- logger.debug(f"{translations['index_algorithm_info']}: {index_algorithm}")
65
-
66
-
67
- try:
68
- feature_dir = os.path.join(exp_dir, f"{version}_extracted")
69
- model_name = os.path.basename(exp_dir)
70
-
71
- npys = []
72
- listdir_res = sorted(os.listdir(feature_dir))
73
-
74
- for name in listdir_res:
75
- file_path = os.path.join(feature_dir, name)
76
- phone = np.load(file_path)
77
- npys.append(phone)
78
-
79
- big_npy = np.concatenate(npys, axis=0)
80
- big_npy_idx = np.arange(big_npy.shape[0])
81
-
82
- np.random.shuffle(big_npy_idx)
83
-
84
- big_npy = big_npy[big_npy_idx]
85
-
86
- if big_npy.shape[0] > 2e5 and (index_algorithm == "Auto" or index_algorithm == "KMeans"): big_npy = (MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * cpu_count(), compute_labels=False, init="random").fit(big_npy).cluster_centers_)
87
-
88
- np.save(os.path.join(exp_dir, "total_fea.npy"), big_npy)
89
-
90
- n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
91
-
92
- index_trained = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
93
-
94
- index_ivf_trained = faiss.extract_index_ivf(index_trained)
95
- index_ivf_trained.nprobe = 1
96
-
97
- index_trained.train(big_npy)
98
-
99
- faiss.write_index(index_trained, os.path.join(exp_dir, f"trained_IVF{n_ivf}_Flat_nprobe_{index_ivf_trained.nprobe}_{model_name}_{version}.index"))
100
-
101
- index_added = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
102
- index_ivf_added = faiss.extract_index_ivf(index_added)
103
-
104
- index_ivf_added.nprobe = 1
105
- index_added.train(big_npy)
106
-
107
- batch_size_add = 8192
108
-
109
- for i in range(0, big_npy.shape[0], batch_size_add):
110
- index_added.add(big_npy[i : i + batch_size_add])
111
-
112
- index_filepath_added = os.path.join(exp_dir, f"added_IVF{n_ivf}_Flat_nprobe_{index_ivf_added.nprobe}_{model_name}_{version}.index")
113
-
114
- faiss.write_index(index_added, index_filepath_added)
115
-
116
- logger.info(f"{translations['save_index']} '{index_filepath_added}'")
117
- except Exception as e:
118
- logger.error(f"{translations['create_index_error']}: {e}")
119
-
120
- if __name__ == "__main__": main()
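For reference, the core of the index construction above in isolation: an IVF{n_ivf},Flat index is trained on the extracted features and vectors are added in 8192-row batches (the feature matrix here is synthetic and the output name is a placeholder).

import faiss
import numpy as np

big_npy = np.random.rand(50000, 768).astype("float32")       # stand-in for v2 features

n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)
index = faiss.index_factory(768, f"IVF{n_ivf},Flat")
faiss.extract_index_ivf(index).nprobe = 1

index.train(big_npy)                                          # learn the IVF centroids
for i in range(0, big_npy.shape[0], 8192):                    # add vectors in batches
    index.add(big_npy[i : i + 8192])

faiss.write_index(index, "added_index.index")                 # placeholder output path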
main/inference/extract.py DELETED
@@ -1,450 +0,0 @@
1
- import os
2
- import gc
3
- import sys
4
- import time
5
- import tqdm
6
- import torch
7
- import shutil
8
- import codecs
9
- import pyworld
10
- import librosa
11
- import logging
12
- import argparse
13
- import warnings
14
- import subprocess
15
- import torchcrepe
16
- import parselmouth
17
- import logging.handlers
18
-
19
- import numpy as np
20
- import soundfile as sf
21
- import torch.nn.functional as F
22
-
23
- from random import shuffle
24
- from functools import partial
25
- from multiprocessing import Pool
26
- from distutils.util import strtobool
27
- from fairseq import checkpoint_utils
28
- from concurrent.futures import ThreadPoolExecutor, as_completed
29
-
30
- now_dir = os.getcwd()
31
- sys.path.append(now_dir)
32
-
33
- from main.configs.config import Config
34
- from main.library.predictors.FCPE import FCPE
35
- from main.library.predictors.RMVPE import RMVPE
36
-
37
- logging.getLogger("wget").setLevel(logging.ERROR)
38
-
39
- warnings.filterwarnings("ignore", category=FutureWarning)
40
- warnings.filterwarnings("ignore", category=UserWarning)
41
-
42
- logger = logging.getLogger(__name__)
43
- logger.propagate = False
44
-
45
- config = Config()
46
- translations = config.translations
47
-
48
- def parse_arguments() -> tuple:
49
- parser = argparse.ArgumentParser()
50
- parser.add_argument("--model_name", type=str, required=True)
51
- parser.add_argument("--rvc_version", type=str, default="v2")
52
- parser.add_argument("--f0_method", type=str, default="rmvpe")
53
- parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
54
- parser.add_argument("--hop_length", type=int, default=128)
55
- parser.add_argument("--cpu_cores", type=int, default=2)
56
- parser.add_argument("--gpu", type=str, default="-")
57
- parser.add_argument("--sample_rate", type=int, required=True)
58
- parser.add_argument("--embedder_model", type=str, default="contentvec_base")
59
-
60
- args = parser.parse_args()
61
- return args
62
-
63
- def load_audio(file, sample_rate):
64
- try:
65
- file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
66
- audio, sr = sf.read(file)
67
-
68
- if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
69
- if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate)
70
- except Exception as e:
71
- raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")
72
-
73
- return audio.flatten()
74
-
75
- def check_rmvpe_fcpe(method):
76
- if method == "rmvpe" and not os.path.exists(os.path.join("assets", "model", "predictors", "rmvpe.pt")): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + "rmvpe.pt", "-P", os.path.join("assets", "model", "predictors")], check=True)
77
- elif method == "fcpe" and not os.path.exists(os.path.join("assets", "model", "predictors", "fcpe.pt")): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + "fcpe.pt", "-P", os.path.join("assets", "model", "predictors")], check=True)
78
-
79
- def check_hubert(hubert):
80
- if hubert == "contentvec_base" or hubert == "hubert_base" or hubert == "japanese_hubert_base" or hubert == "korean_hubert_base" or hubert == "chinese_hubert_base":
81
- model_path = os.path.join(now_dir, "assets", "model", "embedders", hubert + '.pt')
82
-
83
- if not os.path.exists(model_path): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + f"{hubert}.pt", "-P", os.path.join("assets", "model", "embedders")], check=True)
84
-
85
- def generate_config(rvc_version, sample_rate, model_path):
86
- config_path = os.path.join("main", "configs", rvc_version, f"{sample_rate}.json")
87
- config_save_path = os.path.join(model_path, "config.json")
88
- if not os.path.exists(config_save_path): shutil.copy(config_path, config_save_path)
89
-
90
-
91
- def generate_filelist(pitch_guidance, model_path, rvc_version, sample_rate):
92
- gt_wavs_dir = os.path.join(model_path, "sliced_audios")
93
- feature_dir = os.path.join(model_path, f"{rvc_version}_extracted")
94
-
95
- f0_dir, f0nsf_dir = None, None
96
-
97
- if pitch_guidance:
98
- f0_dir = os.path.join(model_path, "f0")
99
- f0nsf_dir = os.path.join(model_path, "f0_voiced")
100
-
101
- gt_wavs_files = set(name.split(".")[0] for name in os.listdir(gt_wavs_dir))
102
- feature_files = set(name.split(".")[0] for name in os.listdir(feature_dir))
103
-
104
- if pitch_guidance:
105
- f0_files = set(name.split(".")[0] for name in os.listdir(f0_dir))
106
- f0nsf_files = set(name.split(".")[0] for name in os.listdir(f0nsf_dir))
107
- names = gt_wavs_files & feature_files & f0_files & f0nsf_files
108
- else: names = gt_wavs_files & feature_files
109
-
110
- options = []
111
- mute_base_path = os.path.join(now_dir, "assets", "logs", "mute")
112
-
113
- for name in names:
114
- if pitch_guidance: options.append(f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|{f0_dir}/{name}.wav.npy|{f0nsf_dir}/{name}.wav.npy|0")
115
- else: options.append(f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|0")
116
-
117
- mute_audio_path = os.path.join(mute_base_path, "sliced_audios", f"mute{sample_rate}.wav")
118
- mute_feature_path = os.path.join(mute_base_path, f"{rvc_version}_extracted", "mute.npy")
119
-
120
- for _ in range(2):
121
- if pitch_guidance:
122
- mute_f0_path = os.path.join(mute_base_path, "f0", "mute.wav.npy")
123
- mute_f0nsf_path = os.path.join(mute_base_path, "f0_voiced", "mute.wav.npy")
124
- options.append(f"{mute_audio_path}|{mute_feature_path}|{mute_f0_path}|{mute_f0nsf_path}|0")
125
- else: options.append(f"{mute_audio_path}|{mute_feature_path}|0")
126
-
127
- shuffle(options)
128
-
129
- with open(os.path.join(model_path, "filelist.txt"), "w") as f:
130
- f.write("\n".join(options))
131
-
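(Note on the format: each record that generate_filelist above writes to filelist.txt is pipe-separated; with pitch guidance enabled the five fields are the ground-truth wav, the HuBERT feature .npy, the coarse-f0 .npy, the raw-f0 .npy and the speaker id, which is hard-coded to 0 here. A hypothetical line for a model named "my_model" (the model name and file stem are placeholders):
assets/logs/my_model/sliced_audios/0_0_0.wav|assets/logs/my_model/v2_extracted/0_0_0.npy|assets/logs/my_model/f0/0_0_0.wav.npy|assets/logs/my_model/f0_voiced/0_0_0.wav.npy|0)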
132
- def setup_paths(exp_dir, version = None):
133
- wav_path = os.path.join(exp_dir, "sliced_audios_16k")
134
- if version:
135
- out_path = os.path.join(exp_dir, "v1_extracted" if version == "v1" else "v2_extracted")
136
- os.makedirs(out_path, exist_ok=True)
137
-
138
- return wav_path, out_path
139
- else:
140
- output_root1 = os.path.join(exp_dir, "f0")
141
- output_root2 = os.path.join(exp_dir, "f0_voiced")
142
-
143
- os.makedirs(output_root1, exist_ok=True)
144
- os.makedirs(output_root2, exist_ok=True)
145
-
146
- return wav_path, output_root1, output_root2
147
-
148
- def read_wave(wav_path, normalize = False):
149
- wav, sr = sf.read(wav_path)
150
- assert sr == 16000, translations["sr_not_16000"]
151
-
152
- feats = torch.from_numpy(wav).float()
153
-
154
- if config.is_half: feats = feats.half()
155
- if feats.dim() == 2: feats = feats.mean(-1)
156
-
157
- feats = feats.view(1, -1)
158
-
159
- if normalize: feats = F.layer_norm(feats, feats.shape)
160
-
161
- return feats
162
-
163
- def get_device(gpu_index):
164
- if gpu_index == "cpu": return "cpu"
165
-
166
- try:
167
- index = int(gpu_index)
168
- if index < torch.cuda.device_count(): return f"cuda:{index}"
169
- else: logger.warning(translations["gpu_not_valid"])
170
- except ValueError:
171
- logger.warning(translations["gpu_not_valid"])
172
-
173
- return "cpu"
174
-
175
- class FeatureInput:
176
- def __init__(self, sample_rate=16000, hop_size=160, device="cpu"):
177
- self.fs = sample_rate
178
- self.hop = hop_size
179
- self.f0_bin = 256
180
- self.f0_max = 1100.0
181
- self.f0_min = 50.0
182
- self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
183
- self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
184
- self.device = device
185
-
186
- def compute_f0(self, np_arr, f0_method, hop_length):
187
- if f0_method == "pm": return self.get_pm(np_arr)
188
- elif f0_method == 'dio': return self.get_dio(np_arr)
189
- elif f0_method == "crepe": return self.get_crepe(np_arr, int(hop_length))
190
- elif f0_method == "crepe-tiny": return self.get_crepe(np_arr, int(hop_length), "tiny")
191
- elif f0_method == "fcpe": return self.get_fcpe(np_arr, int(hop_length))
192
- elif f0_method == "rmvpe": return self.get_rmvpe(np_arr)
193
- elif f0_method == "harvest": return self.get_harvest(np_arr)
194
- else: raise ValueError(translations["method_not_valid"])
195
-
196
- def get_pm(self, x):
197
- time_step = 160 / 16000 * 1000
198
- f0 = (parselmouth.Sound(x, self.fs).to_pitch_ac(time_step=time_step / 1000, voicing_threshold=0.6, pitch_floor=50, pitch_ceiling=1100).selected_array["frequency"])
199
- pad_size = ((x.size // self.hop) - len(f0) + 1) // 2
200
- if pad_size > 0 or (x.size // self.hop) - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, (x.size // self.hop) - len(f0) - pad_size]], mode="constant")
201
-
202
- return f0
203
-
204
- def get_dio(self, x):
205
- f0, t = pyworld.dio(x.astype(np.double), fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
206
- f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs)
207
-
208
- return f0
209
-
210
- def get_crepe(self, x, hop_length, model="full"):
211
- audio = torch.from_numpy(x.astype(np.float32)).to(self.device)
212
- audio /= torch.quantile(torch.abs(audio), 0.999)
213
- audio = audio.unsqueeze(0)
214
-
215
- pitch = torchcrepe.predict(audio, self.fs, hop_length, self.f0_min, self.f0_max, model=model, batch_size=hop_length * 2, device=self.device, pad=True)
216
-
217
- source = pitch.squeeze(0).cpu().float().numpy()
218
- source[source < 0.001] = np.nan
219
- target = np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source)
220
-
221
- return np.nan_to_num(target)
222
-
223
-
224
- def get_fcpe(self, x, hop_length):
225
- self.model_fcpe = FCPE(os.path.join("assets", "model", "predictors", "fcpe.pt"), hop_length=int(hop_length), f0_min=self.f0_min, f0_max=self.f0_max, dtype=torch.float32, device=self.device, sample_rate=self.fs, threshold=0.03)
226
- f0 = self.model_fcpe.compute_f0(x, p_len=(x.size // self.hop))
227
- del self.model_fcpe
228
- gc.collect()
229
- return f0
230
-
231
-
232
- def get_rmvpe(self, x):
233
- self.model_rmvpe = RMVPE(os.path.join("assets", "model", "predictors", "rmvpe.pt"), is_half=False, device=self.device)
234
- return self.model_rmvpe.infer_from_audio(x, thred=0.03)
235
-
236
-
237
- def get_harvest(self, x):
238
- f0, t = pyworld.harvest(x.astype(np.double), fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
239
- f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs)
240
- return f0
241
-
242
-
243
- def coarse_f0(self, f0):
244
- f0_mel = 1127 * np.log(1 + f0 / 700)
245
- f0_mel = np.clip((f0_mel - self.f0_mel_min) * (self.f0_bin - 2) / (self.f0_mel_max - self.f0_mel_min) + 1, 1, self.f0_bin - 1)
246
- return np.rint(f0_mel).astype(int)
247
-
248
-
249
- def process_file(self, file_info, f0_method, hop_length):
250
- inp_path, opt_path1, opt_path2, np_arr = file_info
251
-
252
- if os.path.exists(opt_path1 + ".npy") and os.path.exists(opt_path2 + ".npy"): return
253
-
254
- try:
255
- feature_pit = self.compute_f0(np_arr, f0_method, hop_length)
256
- np.save(opt_path2, feature_pit, allow_pickle=False)
257
- coarse_pit = self.coarse_f0(feature_pit)
258
- np.save(opt_path1, coarse_pit, allow_pickle=False)
259
- except Exception as e:
260
- raise RuntimeError(f"{translations['extract_file_error']} {inp_path}: {e}")
261
-
262
-
263
- def process_files(self, files, f0_method, hop_length, pbar):
264
- for file_info in files:
265
- self.process_file(file_info, f0_method, hop_length)
266
- pbar.update()
267
-
268
- def run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus):
269
- input_root, *output_roots = setup_paths(exp_dir)
270
-
271
- if len(output_roots) == 2: output_root1, output_root2 = output_roots
272
- else:
273
- output_root1 = output_roots[0]
274
- output_root2 = None
275
-
276
- paths = [
277
- (
278
- os.path.join(input_root, name),
279
- os.path.join(output_root1, name) if output_root1 else None,
280
- os.path.join(output_root2, name) if output_root2 else None,
281
- load_audio(os.path.join(input_root, name), 16000),
282
- )
283
- for name in sorted(os.listdir(input_root))
284
- if "spec" not in name
285
- ]
286
-
287
- logger.info(translations["extract_f0_method"].format(num_processes=num_processes, f0_method=f0_method))
288
- start_time = time.time()
289
-
290
- if gpus != "-":
291
- gpus = gpus.split("-")
292
- num_gpus = len(gpus)
293
-
294
- process_partials = []
295
-
296
- pbar = tqdm.tqdm(total=len(paths), desc=translations["extract_f0"])
297
-
298
- for idx, gpu in enumerate(gpus):
299
- device = get_device(gpu)
300
- feature_input = FeatureInput(device=device)
301
-
302
- part_paths = paths[idx::num_gpus]
303
- process_partials.append((feature_input, part_paths))
304
-
305
- with ThreadPoolExecutor() as executor:
306
- futures = [executor.submit(FeatureInput.process_files, feature_input, part_paths, f0_method, hop_length, pbar) for feature_input, part_paths in process_partials]
307
-
308
- for future in as_completed(futures):
309
- pbar.update(1)
310
- future.result()
311
-
312
- pbar.close()
313
- else:
314
- feature_input = FeatureInput(device="cpu")
315
-
316
- with tqdm.tqdm(total=len(paths), desc=translations["extract_f0"]) as pbar:
317
- with Pool(processes=num_processes) as pool:
318
- process_file_partial = partial(feature_input.process_file, f0_method=f0_method, hop_length=hop_length)
319
-
320
- for _ in pool.imap_unordered(process_file_partial, paths):
321
- pbar.update(1)
322
-
323
-
324
- elapsed_time = time.time() - start_time
325
- logger.info(translations["extract_f0_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
326
-
327
- def process_file_embedding(file, wav_path, out_path, model, device, version, saved_cfg):
328
- wav_file_path = os.path.join(wav_path, file)
329
- out_file_path = os.path.join(out_path, file.replace("wav", "npy"))
330
-
331
- if os.path.exists(out_file_path): return
332
-
333
- feats = read_wave(wav_file_path, normalize=saved_cfg.task.normalize)
334
- dtype = torch.float16 if device.startswith("cuda") else torch.float32
335
- feats = feats.to(dtype).to(device)
336
-
337
- padding_mask = torch.BoolTensor(feats.shape).fill_(False).to(dtype).to(device)
338
-
339
- inputs = {
340
- "source": feats,
341
- "padding_mask": padding_mask,
342
- "output_layer": 9 if version == "v1" else 12,
343
- }
344
-
345
- with torch.no_grad():
346
- model = model.to(device).to(dtype)
347
- logits = model.extract_features(**inputs)
348
- feats = model.final_proj(logits[0]) if version == "v1" else logits[0]
349
-
350
- feats = feats.squeeze(0).float().cpu().numpy()
351
-
352
- if not np.isnan(feats).any(): np.save(out_file_path, feats, allow_pickle=False)
353
- else: logger.warning(f"{file} {translations['NaN']}")
354
-
355
- def run_embedding_extraction(exp_dir, version, gpus, embedder_model):
356
- wav_path, out_path = setup_paths(exp_dir, version)
357
-
358
- logger.info(translations["start_extract_hubert"])
359
- start_time = time.time()
360
-
361
- try:
362
- models, saved_cfg, _ = checkpoint_utils.load_model_ensemble_and_task([os.path.join(now_dir, "assets", "model", "embedders", embedder_model + '.pt')], suffix="")
363
- except Exception as e:
364
- raise ImportError(translations["read_model_error"].format(e=e))
365
-
366
- model = models[0]
367
- devices = [get_device(gpu) for gpu in (gpus.split("-") if gpus != "-" else ["cpu"])]
368
-
369
- paths = sorted([file for file in os.listdir(wav_path) if file.endswith(".wav")])
370
-
371
- if not paths:
372
- logger.warning(translations["not_found_audio_file"])
373
- sys.exit(1)
374
-
375
- pbar = tqdm.tqdm(total=len(paths) * len(devices), desc=translations["extract_hubert"])
376
-
377
- tasks = [(file, wav_path, out_path, model, device, version, saved_cfg) for file in paths for device in devices]
378
-
379
- for task in tasks:
380
- try:
381
- process_file_embedding(*task)
382
- except Exception as e:
383
- raise RuntimeError(f"{translations['process_error']} {task[0]}: {e}")
384
-
385
- pbar.update(1)
386
-
387
- pbar.close()
388
-
389
- elapsed_time = time.time() - start_time
390
- logger.info(translations["extract_hubert_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
391
-
392
- if __name__ == "__main__":
393
- args = parse_arguments()
394
-
395
- exp_dir = os.path.join("assets", "logs", args.model_name)
396
- f0_method = args.f0_method
397
- hop_length = args.hop_length
398
- num_processes = args.cpu_cores
399
- gpus = args.gpu
400
- version = args.rvc_version
401
- pitch_guidance = args.pitch_guidance
402
- sample_rate = args.sample_rate
403
- embedder_model = args.embedder_model
404
-
405
- check_rmvpe_fcpe(f0_method)
406
- check_hubert(embedder_model)
407
-
408
- if len([f for f in os.listdir(os.path.join(exp_dir, "sliced_audios")) if os.path.isfile(os.path.join(exp_dir, "sliced_audios", f))]) < 1 or len([f for f in os.listdir(os.path.join(exp_dir, "sliced_audios_16k")) if os.path.isfile(os.path.join(exp_dir, "sliced_audios_16k", f))]) < 1: raise FileNotFoundError("No processed data found, please process the audio again")
409
-
410
- log_file = os.path.join(exp_dir, "extract.log")
411
-
412
- if logger.hasHandlers(): logger.handlers.clear()
413
- else:
414
- console_handler = logging.StreamHandler()
415
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
416
-
417
- console_handler.setFormatter(console_formatter)
418
- console_handler.setLevel(logging.INFO)
419
-
420
- file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
421
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
422
-
423
- file_handler.setFormatter(file_formatter)
424
- file_handler.setLevel(logging.DEBUG)
425
-
426
- logger.addHandler(console_handler)
427
- logger.addHandler(file_handler)
428
- logger.setLevel(logging.DEBUG)
429
-
430
- logger.debug(f"{translations['modelname']}: {args.model_name}")
431
- logger.debug(f"{translations['export_process']}: {exp_dir}")
432
- logger.debug(f"{translations['f0_method']}: {f0_method}")
433
- logger.debug(f"{translations['pretrain_sr']}: {sample_rate}")
434
- logger.debug(f"{translations['cpu_core']}: {num_processes}")
435
- logger.debug(f"Gpu: {gpus}")
436
- if f0_method == "crepe" or f0_method == "crepe-tiny" or f0_method == "fcpe": logger.debug(f"Hop length: {hop_length}")
437
- logger.debug(f"{translations['training_version']}: {version}")
438
- logger.debug(f"{translations['extract_f0']}: {pitch_guidance}")
439
- logger.debug(f"{translations['hubert_model']}: {embedder_model}")
440
-
441
- try:
442
- run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus)
443
- run_embedding_extraction(exp_dir, version, gpus, embedder_model)
444
-
445
- generate_config(version, sample_rate, exp_dir)
446
- generate_filelist(pitch_guidance, exp_dir, version, sample_rate)
447
- except Exception as e:
448
- logger.error(f"{translations['extract_error']}: {e}")
449
-
450
- logger.info(f"{translations['extract_success']} {args.model_name}.")
 
main/inference/preprocess.py DELETED
@@ -1,360 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import torch
5
- import logging
6
- import librosa
7
- import argparse
8
- import logging.handlers
9
-
10
- import numpy as np
11
- import soundfile as sf
12
- import multiprocessing
13
- import noisereduce as nr
14
-
15
- from tqdm import tqdm
16
- from scipy import signal
17
- from scipy.io import wavfile
18
- from distutils.util import strtobool
19
- from concurrent.futures import ProcessPoolExecutor, as_completed
20
-
21
- now_directory = os.getcwd()
22
- sys.path.append(now_directory)
23
-
24
- from main.configs.config import Config
25
-
26
- logger = logging.getLogger(__name__)
27
-
28
-
29
- logging.getLogger("numba.core.byteflow").setLevel(logging.ERROR)
30
- logging.getLogger("numba.core.ssa").setLevel(logging.ERROR)
31
- logging.getLogger("numba.core.interpreter").setLevel(logging.ERROR)
32
-
33
- OVERLAP = 0.3
34
- MAX_AMPLITUDE = 0.9
35
- ALPHA = 0.75
36
- HIGH_PASS_CUTOFF = 48
37
- SAMPLE_RATE_16K = 16000
38
-
39
-
40
- config = Config()
41
- per = 3.0 if config.is_half else 3.7
42
- translations = config.translations
43
-
44
-
45
- def parse_arguments() -> argparse.Namespace:
46
- parser = argparse.ArgumentParser()
47
- parser.add_argument("--model_name", type=str, required=True)
48
- parser.add_argument("--dataset_path", type=str, default="./dataset")
49
- parser.add_argument("--sample_rate", type=int, required=True)
50
- parser.add_argument("--cpu_cores", type=int, default=2)
51
- parser.add_argument("--cut_preprocess", type=lambda x: bool(strtobool(x)), default=True)
52
- parser.add_argument("--process_effects", type=lambda x: bool(strtobool(x)), default=False)
53
- parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
54
- parser.add_argument("--clean_strength", type=float, default=0.7)
55
-
56
- args = parser.parse_args()
57
- return args
58
-
59
-
60
- def load_audio(file, sample_rate):
61
- try:
62
- file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
63
- audio, sr = sf.read(file)
64
-
65
- if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
66
- if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate)
67
- except Exception as e:
68
- raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")
69
-
70
- return audio.flatten()
71
-
72
- class Slicer:
73
- def __init__(self, sr, threshold = -40.0, min_length = 5000, min_interval = 300, hop_size = 20, max_sil_kept = 5000):
74
- if not min_length >= min_interval >= hop_size: raise ValueError(translations["min_length>=min_interval>=hop_size"])
75
- if not max_sil_kept >= hop_size: raise ValueError(translations["max_sil_kept>=hop_size"])
76
-
77
- min_interval = sr * min_interval / 1000
78
- self.threshold = 10 ** (threshold / 20.0)
79
- self.hop_size = round(sr * hop_size / 1000)
80
- self.win_size = min(round(min_interval), 4 * self.hop_size)
81
- self.min_length = round(sr * min_length / 1000 / self.hop_size)
82
- self.min_interval = round(min_interval / self.hop_size)
83
- self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)
84
-
85
- def _apply_slice(self, waveform, begin, end):
86
- start_idx = begin * self.hop_size
87
-
88
- if len(waveform.shape) > 1:
89
- end_idx = min(waveform.shape[1], end * self.hop_size)
90
- return waveform[:, start_idx:end_idx]
91
- else:
92
- end_idx = min(waveform.shape[0], end * self.hop_size)
93
- return waveform[start_idx:end_idx]
94
-
95
- def slice(self, waveform):
96
- samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform
97
- if samples.shape[0] <= self.min_length: return [waveform]
98
-
99
- rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)
100
-
101
- sil_tags = []
102
- silence_start, clip_start = None, 0
103
-
104
- for i, rms in enumerate(rms_list):
105
- if rms < self.threshold:
106
- if silence_start is None: silence_start = i
107
- continue
108
-
109
- if silence_start is None: continue
110
-
111
- is_leading_silence = silence_start == 0 and i > self.max_sil_kept
112
- need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)
113
-
114
- if not is_leading_silence and not need_slice_middle:
115
- silence_start = None
116
- continue
117
-
118
- if i - silence_start <= self.max_sil_kept:
119
- pos = rms_list[silence_start : i + 1].argmin() + silence_start
120
- if silence_start == 0: sil_tags.append((0, pos))
121
- else: sil_tags.append((pos, pos))
122
-
123
- clip_start = pos
124
-
125
- elif i - silence_start <= self.max_sil_kept * 2:
126
- pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()
127
-
128
- pos += i - self.max_sil_kept
129
-
130
- pos_l = (rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start)
131
- pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
132
-
133
- if silence_start == 0:
134
- sil_tags.append((0, pos_r))
135
- clip_start = pos_r
136
- else:
137
- sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
138
- clip_start = max(pos_r, pos)
139
- else:
140
- pos_l = (rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start)
141
- pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)
142
-
143
- if silence_start == 0: sil_tags.append((0, pos_r))
144
- else: sil_tags.append((pos_l, pos_r))
145
-
146
- clip_start = pos_r
147
- silence_start = None
148
-
149
- total_frames = rms_list.shape[0]
150
-
151
- if (silence_start is not None and total_frames - silence_start >= self.min_interval):
152
- silence_end = min(total_frames, silence_start + self.max_sil_kept)
153
- pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
154
- sil_tags.append((pos, total_frames + 1))
155
-
156
- if not sil_tags: return [waveform]
157
- else:
158
- chunks = []
159
-
160
- if sil_tags[0][0] > 0: chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))
161
-
162
- for i in range(len(sil_tags) - 1):
163
- chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))
164
-
165
- if sil_tags[-1][1] < total_frames: chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))
166
-
167
- return chunks
168
-
169
-
170
- def get_rms(y, frame_length=2048, hop_length=512, pad_mode="constant"):
171
- padding = (int(frame_length // 2), int(frame_length // 2))
172
- y = np.pad(y, padding, mode=pad_mode)
173
-
174
- axis = -1
175
- out_strides = y.strides + tuple([y.strides[axis]])
176
- x_shape_trimmed = list(y.shape)
177
- x_shape_trimmed[axis] -= frame_length - 1
178
- out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
179
- xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)
180
-
181
- target_axis = axis - 1 if axis < 0 else axis + 1
182
-
183
-
184
- xw = np.moveaxis(xw, -1, target_axis)
185
- slices = [slice(None)] * xw.ndim
186
- slices[axis] = slice(0, None, hop_length)
187
- x = xw[tuple(slices)]
188
-
189
- power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
190
- return np.sqrt(power)
191
-
192
-
193
- class PreProcess:
194
- def __init__(self, sr, exp_dir, per):
195
- self.slicer = Slicer(sr=sr, threshold=-42, min_length=1500, min_interval=400, hop_size=15, max_sil_kept=500)
196
- self.sr = sr
197
- self.b_high, self.a_high = signal.butter(N=5, Wn=HIGH_PASS_CUTOFF, btype="high", fs=self.sr)
198
- self.per = per
199
- self.exp_dir = exp_dir
200
- self.device = "cpu"
201
- self.gt_wavs_dir = os.path.join(exp_dir, "sliced_audios")
202
- self.wavs16k_dir = os.path.join(exp_dir, "sliced_audios_16k")
203
-
204
- os.makedirs(self.gt_wavs_dir, exist_ok=True)
205
- os.makedirs(self.wavs16k_dir, exist_ok=True)
206
-
207
- def _normalize_audio(self, audio: np.ndarray):
208
- tmp_max = np.abs(audio).max()
209
- if tmp_max > 2.5: return None
210
-
211
- return (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio
212
-
213
- def process_audio_segment(self, normalized_audio: np.ndarray, sid, idx0, idx1):
214
- if normalized_audio is None:
215
- logs(f"{sid}-{idx0}-{idx1}-filtered")
216
- return
217
-
218
- wavfile.write(os.path.join(self.gt_wavs_dir, f"{sid}_{idx0}_{idx1}.wav"), self.sr, normalized_audio.astype(np.float32))
219
- audio_16k = librosa.resample(normalized_audio, orig_sr=self.sr, target_sr=SAMPLE_RATE_16K)
220
- wavfile.write(os.path.join(self.wavs16k_dir, f"{sid}_{idx0}_{idx1}.wav"), SAMPLE_RATE_16K, audio_16k.astype(np.float32))
221
-
222
- def process_audio(self, path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength):
223
- try:
224
- audio = load_audio(path, self.sr)
225
-
226
- if process_effects:
227
- audio = signal.lfilter(self.b_high, self.a_high, audio)
228
- audio = self._normalize_audio(audio)
229
-
230
- if clean_dataset: audio = nr.reduce_noise(y=audio, sr=self.sr, prop_decrease=clean_strength)
231
-
232
- idx1 = 0
233
-
234
- if cut_preprocess:
235
- for audio_segment in self.slicer.slice(audio):
236
- i = 0
237
-
238
- while 1:
239
- start = int(self.sr * (self.per - OVERLAP) * i)
240
- i += 1
241
-
242
- if len(audio_segment[start:]) > (self.per + OVERLAP) * self.sr:
243
- tmp_audio = audio_segment[start : start + int(self.per * self.sr)]
244
- self.process_audio_segment(tmp_audio, sid, idx0, idx1)
245
- idx1 += 1
246
- else:
247
- tmp_audio = audio_segment[start:]
248
- self.process_audio_segment(tmp_audio, sid, idx0, idx1)
249
- idx1 += 1
250
- break
251
-
252
- else: self.process_audio_segment(audio, sid, idx0, idx1)
253
- except Exception as e:
254
- raise RuntimeError(f"{translations['process_audio_error']}: {e}")
255
-
256
- def process_file(args):
257
- pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength = (args)
258
- file_path, idx0, sid = file
259
-
260
- pp.process_audio(file_path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength)
261
-
262
- def preprocess_training_set(input_root, sr, num_processes, exp_dir, per, cut_preprocess, process_effects, clean_dataset, clean_strength):
263
- start_time = time.time()
264
-
265
- pp = PreProcess(sr, exp_dir, per)
266
- logger.info(translations["start_preprocess"].format(num_processes=num_processes))
267
-
268
- files = []
269
- idx = 0
270
-
271
- for root, _, filenames in os.walk(input_root):
272
- try:
273
- sid = 0 if root == input_root else int(os.path.basename(root))
274
- for f in filenames:
275
- if f.lower().endswith((".wav", ".mp3", ".flac", ".ogg")):
276
- files.append((os.path.join(root, f), idx, sid))
277
- idx += 1
278
- except ValueError:
279
- raise ValueError(f"{translations['not_integer']} '{os.path.basename(root)}'.")
280
-
281
- with tqdm(total=len(files), desc=translations["preprocess"]) as pbar:
282
- with ProcessPoolExecutor(max_workers=num_processes) as executor:
283
- futures = [
284
- executor.submit(
285
- process_file,
286
- (
287
- pp,
288
- file,
289
- cut_preprocess,
290
- process_effects,
291
- clean_dataset,
292
- clean_strength,
293
- ),
294
- )
295
- for file in files
296
- ]
297
- for future in as_completed(futures):
298
- try:
299
- future.result()
300
- except Exception as e:
301
- raise RuntimeError(f"{translations['process_error']}: {e}")
302
-
303
- pbar.update(1)
304
-
305
- elapsed_time = time.time() - start_time
306
- logger.info(translations["preprocess_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
307
-
308
- if __name__ == "__main__":
309
- args = parse_arguments()
310
-
311
- experiment_directory = os.path.join("assets", "logs", args.model_name)
312
- num_processes = args.cpu_cores
313
- num_processes = multiprocessing.cpu_count() if num_processes is None else int(num_processes)
314
- dataset = args.dataset_path
315
- sample_rate = args.sample_rate
316
- cut_preprocess = args.cut_preprocess
317
- preprocess_effects = args.process_effects
318
- clean_dataset = args.clean_dataset
319
- clean_strength = args.clean_strength
320
-
321
- os.makedirs(experiment_directory, exist_ok=True)
322
-
323
- if len([f for f in os.listdir(os.path.join(dataset)) if os.path.isfile(os.path.join(dataset, f)) and f.lower().endswith((".wav", ".mp3", ".flac", ".ogg"))]) < 1: raise FileNotFoundError("No data found")
324
-
325
- log_file = os.path.join(experiment_directory, "preprocess.log")
326
-
327
- if logger.hasHandlers(): logger.handlers.clear()
328
- else:
329
- console_handler = logging.StreamHandler()
330
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
331
-
332
- console_handler.setFormatter(console_formatter)
333
- console_handler.setLevel(logging.INFO)
334
-
335
- file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
336
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
337
-
338
- file_handler.setFormatter(file_formatter)
339
- file_handler.setLevel(logging.DEBUG)
340
-
341
- logger.addHandler(console_handler)
342
- logger.addHandler(file_handler)
343
- logger.setLevel(logging.DEBUG)
344
-
345
- logger.debug(f"{translations['modelname']}: {args.model_name}")
346
- logger.debug(f"{translations['export_process']}: {experiment_directory}")
347
- logger.debug(f"{translations['dataset_folder']}: {dataset}")
348
- logger.debug(f"{translations['pretrain_sr']}: {sample_rate}")
349
- logger.debug(f"{translations['cpu_core']}: {num_processes}")
350
- logger.debug(f"{translations['split_audio']}: {cut_preprocess}")
351
- logger.debug(f"{translations['preprocess_effect']}: {preprocess_effects}")
352
- logger.debug(f"{translations['clear_audio']}: {clean_dataset}")
353
- if clean_dataset: logger.debug(f"{translations['clean_strength']}: {clean_strength}")
354
-
355
- try:
356
- preprocess_training_set(dataset, sample_rate, num_processes, experiment_directory, per, cut_preprocess, preprocess_effects, clean_dataset, clean_strength)
357
- except Exception as e:
358
- logger.error(f"{translations['process_audio_error']} {e}")
359
-
360
- logger.info(f"{translations['preprocess_model_success']} {args.model_name}")
 
main/inference/separator_music.py DELETED
@@ -1,400 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import logging
5
- import argparse
6
- import logging.handlers
7
-
8
- import soundfile as sf
9
- import noisereduce as nr
10
-
11
- from pydub import AudioSegment
12
- from distutils.util import strtobool
13
-
14
- now_dir = os.getcwd()
15
- sys.path.append(now_dir)
16
-
17
- from main.configs.config import Config
18
- from main.library.algorithm.separator import Separator
19
-
20
- translations = Config().translations
21
-
22
- log_file = os.path.join("assets", "logs", "separator.log")
23
- logger = logging.getLogger(__name__)
24
-
25
- if logger.hasHandlers(): logger.handlers.clear()
26
- else:
27
- console_handler = logging.StreamHandler()
28
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
29
-
30
- console_handler.setFormatter(console_formatter)
31
- console_handler.setLevel(logging.INFO)
32
-
33
- file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
34
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
35
-
36
- file_handler.setFormatter(file_formatter)
37
- file_handler.setLevel(logging.DEBUG)
38
-
39
- logger.addHandler(console_handler)
40
- logger.addHandler(file_handler)
41
- logger.setLevel(logging.DEBUG)
42
-
43
- demucs_models = {
44
- "HT-Tuned": "htdemucs_ft.yaml",
45
- "HT-Normal": "htdemucs.yaml",
46
- "HD_MMI": "hdemucs_mmi.yaml",
47
- "HT_6S": "htdemucs_6s.yaml"
48
- }
49
-
50
- mdx_models = {
51
- "Main_340": "UVR-MDX-NET_Main_340.onnx",
52
- "Main_390": "UVR-MDX-NET_Main_390.onnx",
53
- "Main_406": "UVR-MDX-NET_Main_406.onnx",
54
- "Main_427": "UVR-MDX-NET_Main_427.onnx",
55
- "Main_438": "UVR-MDX-NET_Main_438.onnx",
56
- "Inst_full_292": "UVR-MDX-NET-Inst_full_292.onnx",
57
- "Inst_HQ_1": "UVR-MDX-NET_Inst_HQ_1.onnx",
58
- "Inst_HQ_2": "UVR-MDX-NET_Inst_HQ_2.onnx",
59
- "Inst_HQ_3": "UVR-MDX-NET_Inst_HQ_3.onnx",
60
- "Inst_HQ_4": "UVR-MDX-NET-Inst_HQ_4.onnx",
61
- "Kim_Vocal_1": "Kim_Vocal_1.onnx",
62
- "Kim_Vocal_2": "Kim_Vocal_2.onnx",
63
- "Kim_Inst": "Kim_Inst.onnx",
64
- "Inst_187_beta": "UVR-MDX-NET_Inst_187_beta.onnx",
65
- "Inst_82_beta": "UVR-MDX-NET_Inst_82_beta.onnx",
66
- "Inst_90_beta": "UVR-MDX-NET_Inst_90_beta.onnx",
67
- "Voc_FT": "UVR-MDX-NET-Voc_FT.onnx",
68
- "Crowd_HQ": "UVR-MDX-NET_Crowd_HQ_1.onnx",
69
- "MDXNET_9482": "UVR_MDXNET_9482.onnx",
70
- "Inst_1": "UVR-MDX-NET-Inst_1.onnx",
71
- "Inst_2": "UVR-MDX-NET-Inst_2.onnx",
72
- "Inst_3": "UVR-MDX-NET-Inst_3.onnx",
73
- "MDXNET_1_9703": "UVR_MDXNET_1_9703.onnx",
74
- "MDXNET_2_9682": "UVR_MDXNET_2_9682.onnx",
75
- "MDXNET_3_9662": "UVR_MDXNET_3_9662.onnx",
76
- "Inst_Main": "UVR-MDX-NET-Inst_Main.onnx",
77
- "MDXNET_Main": "UVR_MDXNET_Main.onnx"
78
- }
79
-
80
- kara_models = {
81
- "Version-1": "UVR_MDXNET_KARA.onnx",
82
- "Version-2": "UVR_MDXNET_KARA_2.onnx"
83
- }
84
-
85
-
86
- def parse_arguments() -> argparse.Namespace:
87
- parser = argparse.ArgumentParser()
88
- parser.add_argument("--input_path", type=str, required=True)
89
- parser.add_argument("--output_path", type=str, default="./audios")
90
- parser.add_argument("--format", type=str, default="wav")
91
- parser.add_argument("--shifts", type=int, default=10)
92
- parser.add_argument("--segments_size", type=int, default=256)
93
- parser.add_argument("--overlap", type=float, default=0.25)
94
- parser.add_argument("--mdx_hop_length", type=int, default=1024)
95
- parser.add_argument("--mdx_batch_size", type=int, default=1)
96
- parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
97
- parser.add_argument("--clean_strength", type=float, default=0.7)
98
- parser.add_argument("--backing_denoise", type=lambda x: bool(strtobool(x)), default=False)
99
- parser.add_argument("--demucs_model", type=str, default="HT-Normal")
100
- parser.add_argument("--kara_model", type=str, default="Version-1")
101
- parser.add_argument("--mdx_model", type=str, default="Main_340")
102
- parser.add_argument("--backing", type=lambda x: bool(strtobool(x)), default=False)
103
- parser.add_argument("--mdx", type=lambda x: bool(strtobool(x)), default=False)
104
- parser.add_argument("--mdx_denoise", type=lambda x: bool(strtobool(x)), default=False)
105
- parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
106
- parser.add_argument("--reverb_denoise", type=lambda x: bool(strtobool(x)), default=False)
107
- parser.add_argument("--backing_reverb", type=lambda x: bool(strtobool(x)), default=False)
108
-
109
- args = parser.parse_args()
110
- return args
111
-
112
- def main():
113
- start_time = time.time()
114
-
115
- try:
116
- args = parse_arguments()
117
-
118
- input_path = args.input_path
119
- output_path = args.output_path
120
- export_format = args.format
121
- shifts = args.shifts
122
- segments_size = args.segments_size
123
- overlap = args.overlap
124
- hop_length = args.mdx_hop_length
125
- batch_size = args.mdx_batch_size
126
- clean_audio = args.clean_audio
127
- clean_strength = args.clean_strength
128
- backing_denoise = args.backing_denoise
129
- demucs_model = args.demucs_model
130
- kara_model = args.kara_model
131
- backing = args.backing
132
- mdx = args.mdx
133
- mdx_model = args.mdx_model
134
- mdx_denoise = args.mdx_denoise
135
- reverb = args.reverb
136
- reverb_denoise = args.reverb_denoise
137
- backing_reverb = args.backing_reverb
138
-
139
- if backing_reverb and not reverb:
140
- logger.warning(translations["turn_on_dereverb"])
141
- sys.exit(1)
142
-
143
- if backing_reverb and not backing:
144
- logger.warning(translations["turn_on_separator_backing"])
145
- sys.exit(1)
146
-
147
- logger.debug(f"{translations['audio_path']}: {input_path}")
148
- logger.debug(f"{translations['output_path']}: {output_path}")
149
- logger.debug(f"{translations['export_format']}: {export_format}")
150
- if not mdx: logger.debug(f"{translations['shift']}: {shifts}")
151
- logger.debug(f"{translations['segments_size']}: {segments_size}")
152
- logger.debug(f"{translations['overlap']}: {overlap}")
153
- if clean_audio: logger.debug(f"{translations['clear_audio']}: {clean_audio}")
154
- if clean_audio: logger.debug(f"{translations['clean_strength']}: {clean_strength}")
155
- if not mdx: logger.debug(f"{translations['demucs_model']}: {demucs_model}")
156
- if backing: logger.debug(f"{translations['denoise_backing']}: {backing_denoise}")
157
- if backing: logger.debug(f"{translations['backing_model_ver']}: {kara_model}")
158
- if backing: logger.debug(f"{translations['separator_backing']}: {backing}")
159
- if mdx: logger.debug(f"{translations['use_mdx']}: {mdx}")
160
- if mdx: logger.debug(f"{translations['mdx_model']}: {mdx_model}")
161
- if mdx: logger.debug(f"{translations['denoise_mdx']}: {mdx_denoise}")
162
- if mdx or backing or reverb: logger.debug(f"Hop length: {hop_length}")
163
- if mdx or backing or reverb: logger.debug(f"{translations['batch_size']}: {batch_size}")
164
- if reverb: logger.debug(f"{translations['dereveb_audio']}: {reverb}")
165
- if reverb: logger.debug(f"{translations['denoise_dereveb']}: {reverb_denoise}")
166
- if reverb: logger.debug(f"{translations['dereveb_backing']}: {backing_reverb}")
167
-
168
- if not mdx: vocals, instruments = separator_music_demucs(input_path, output_path, export_format, shifts, overlap, segments_size, demucs_model)
169
- else: vocals, instruments = separator_music_mdx(input_path, output_path, export_format, segments_size, overlap, mdx_denoise, mdx_model, hop_length, batch_size)
170
-
171
- if backing: main_vocals, backing_vocals = separator_backing(vocals, output_path, export_format, segments_size, overlap, backing_denoise, kara_model, hop_length, batch_size)
172
- if reverb: vocals_no_reverb, main_vocals_no_reverb, backing_vocals_no_reverb = separator_reverb(output_path, export_format, segments_size, overlap, reverb_denoise, reverb, backing, backing_reverb, hop_length, batch_size)
173
-
174
- original_output = os.path.join(output_path, f"Original_Vocals_No_Reverb.{export_format}") if reverb else os.path.join(output_path, f"Original_Vocals.{export_format}")
175
- main_output = os.path.join(output_path, f"Main_Vocals_No_Reverb.{export_format}") if reverb and backing else os.path.join(output_path, f"Main_Vocals.{export_format}")
176
- backing_output = os.path.join(output_path, f"Backing_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Backing_Vocals.{export_format}")
177
-
178
- if clean_audio:
179
- logger.info(f"{translations['clear_audio']}...")
180
- vocal_data, vocal_sr = sf.read(vocals_no_reverb if reverb else vocals)
181
- main_data, main_sr = sf.read(main_vocals_no_reverb if reverb and backing else main_vocals)
182
- backing_data, backing_sr = sf.read(backing_vocals_no_reverb if reverb and backing_reverb else backing_vocals)
183
-
184
- vocals_clean = nr.reduce_noise(y=vocal_data, sr=vocal_sr, prop_decrease=clean_strength)
185
-
186
- sf.write(original_output, vocals_clean, vocal_sr, format=export_format)
187
-
188
- if backing:
189
- mains_clean = nr.reduce_noise(y=main_data, sr=main_sr, prop_decrease=clean_strength)
190
- backing_clean = nr.reduce_noise(y=backing_data, sr=backing_sr, prop_decrease=clean_strength)
191
- sf.write(main_output, mains_clean, main_sr, format=export_format)
192
- sf.write(backing_output, backing_clean, backing_sr, format=export_format)
193
-
194
- logger.info(translations["clean_audio_success"])
195
- return original_output, instruments, main_output, backing_output
196
- except Exception as e:
197
- logger.error(f"{translations['separator_error']}: {e}")
198
- return None, None, None, None
199
-
200
- elapsed_time = time.time() - start_time
201
- logger.info(translations["separator_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
202
-
203
-
204
- def separator_music_demucs(input, output, format, shifts, overlap, segments_size, demucs_model):
205
- if not os.path.exists(input):
206
- logger.warning(translations["input_not_valid"])
207
- sys.exit(1)
208
-
209
- if not os.path.exists(output):
210
- logger.warning(translations["output_not_valid"])
211
- sys.exit(1)
212
-
213
- for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
214
- if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
215
-
216
- model = demucs_models.get(demucs_model)
217
-
218
- segment_size = segments_size / 2
219
-
220
- logger.info(f"{translations['separator_process_2']}...")
221
-
222
- demucs_output = separator_main(audio_file=input, model_filename=model, output_format=format, output_dir=output, demucs_segment_size=segment_size, demucs_shifts=shifts, demucs_overlap=overlap)
223
-
224
- for f in demucs_output:
225
- path = os.path.join(output, f)
226
-
227
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
228
-
229
- if '_(Drums)_' in f: drums = path
230
- elif '_(Bass)_' in f: bass = path
231
- elif '_(Other)_' in f: other = path
232
- elif '_(Vocals)_' in f: os.rename(path, os.path.join(output, f"Original_Vocals.{format}"))
233
-
234
- AudioSegment.from_file(drums).overlay(AudioSegment.from_file(bass)).overlay(AudioSegment.from_file(other)).export(os.path.join(output, f"Instruments.{format}"), format=format)
235
-
236
- for f in [drums, bass, other]:
237
- if os.path.exists(f): os.remove(f)
238
-
239
- logger.info(translations["separator_success_2"])
240
- return os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")
241
-
242
- def separator_backing(input, output, format, segments_size, overlap, denoise, kara_model, hop_length, batch_size):
243
- if not os.path.exists(input):
244
- logger.warning(translations["input_not_valid"])
245
- sys.exit(1)
246
-
247
- if not os.path.exists(output):
248
- logger.warning(translations["output_not_valid"])
249
- sys.exit(1)
250
-
251
- for f in [f"Main_Vocals.{format}", f"Backing_Vocals.{format}"]:
252
- if os.path.exists(os.path.join(output, f)): os.remove(os.path.join(output, f))
253
-
254
- model_2 = kara_models.get(kara_model)
255
-
256
- logger.info(f"{translations['separator_process_backing']}...")
257
- backing_outputs = separator_main(audio_file=input, model_filename=model_2, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise)
258
-
259
- main_output = os.path.join(output, f"Main_Vocals.{format}")
260
- backing_output = os.path.join(output, f"Backing_Vocals.{format}")
261
-
262
- for f in backing_outputs:
263
- path = os.path.join(output, f)
264
-
265
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
266
-
267
- if '_(Instrumental)_' in f: os.rename(path, backing_output)
268
- elif '_(Vocals)_' in f: os.rename(path, main_output)
269
-
270
- logger.info(translations["separator_process_backing_success"])
271
- return main_output, backing_output
272
-
273
- def separator_music_mdx(input, output, format, segments_size, overlap, denoise, mdx_model, hop_length, batch_size):
274
- if not os.path.exists(input):
275
- logger.warning(translations["input_not_valid"])
276
- sys.exit(1)
277
-
278
- if not os.path.exists(output):
279
- logger.warning(translations["output_not_valid"])
280
- sys.exit(1)
281
-
282
- for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
283
- if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
284
-
285
- model_3 = mdx_models.get(mdx_model)
286
-
287
- logger.info(f"{translations['separator_process_2']}...")
288
- output_music = separator_main(audio_file=input, model_filename=model_3, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise)
289
-
290
- original_output = os.path.join(output, f"Original_Vocals.{format}")
291
- instruments_output = os.path.join(output, f"Instruments.{format}")
292
-
293
- for f in output_music:
294
- path = os.path.join(output, f)
295
-
296
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
297
-
298
- if '_(Instrumental)_' in f: os.rename(path, instruments_output)
299
- elif '_(Vocals)_' in f: os.rename(path, original_output)
300
-
301
- logger.info(translations["separator_process_backing_success"])
302
- return original_output, instruments_output
303
-
304
- def separator_reverb(output, format, segments_size, overlap, denoise, original, main, backing_reverb, hop_length, batch_size):
305
- if not os.path.exists(output):
306
- logger.warning(translations["output_not_valid"])
307
- sys.exit(1)
308
-
309
- for i in [f"Original_Vocals_Reverb.{format}", f"Main_Vocals_Reverb.{format}", f"Original_Vocals_No_Reverb.{format}", f"Main_Vocals_No_Reverb.{format}"]:
310
- if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
311
-
312
- dereveb_path = []
313
-
314
- if original:
315
- try:
316
- dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Original_Vocals' in f][0]))
317
- except IndexError:
318
- logger.warning(translations["not_found_original_vocal"])
319
- sys.exit(1)
320
-
321
- if main:
322
- try:
323
- dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Main_Vocals' in f][0]))
324
- except IndexError:
325
- logger.warning(translations["not_found_main_vocal"])
326
- sys.exit(1)
327
-
328
- if backing_reverb:
329
- try:
330
- dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Backing_Vocals' in f][0]))
331
- except IndexError:
332
- logger.warning(translations["not_found_backing_vocal"])
333
- sys.exit(1)
334
-
335
- for path in dereveb_path:
336
- if not os.path.exists(path):
337
- logger.warning(translations["not_found"].format(name=path))
338
- sys.exit(1)
339
-
340
- if "Original_Vocals" in path:
341
- reverb_path = os.path.join(output, f"Original_Vocals_Reverb.{format}")
342
- no_reverb_path = os.path.join(output, f"Original_Vocals_No_Reverb.{format}")
343
- start_title = translations["process_original"]
344
- end_title = translations["process_original_success"]
345
- elif "Main_Vocals" in path:
346
- reverb_path = os.path.join(output, f"Main_Vocals_Reverb.{format}")
347
- no_reverb_path = os.path.join(output, f"Main_Vocals_No_Reverb.{format}")
348
- start_title = translations["process_main"]
349
- end_title = translations["process_main_success"]
350
- elif "Backing_Vocals" in path:
351
- reverb_path = os.path.join(output, f"Backing_Vocals_Reverb.{format}")
352
- no_reverb_path = os.path.join(output, f"Backing_Vocals_No_Reverb.{format}")
353
- start_title = translations["process_backing"]
354
- end_title = translations["process_backing_success"]
355
-
356
- logger.info(start_title)
357
- output_dereveb = separator_main(audio_file=path, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise)
358
-
359
- for f in output_dereveb:
360
- path = os.path.join(output, f)
361
-
362
- if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
363
-
364
- if '_(Reverb)_' in f: os.rename(path, reverb_path)
365
- elif '_(No Reverb)_' in f: os.rename(path, no_reverb_path)
366
-
367
- logger.info(end_title)
368
-
369
- return (os.path.join(output, f"Original_Vocals_No_Reverb.{format}") if original else None), (os.path.join(output, f"Main_Vocals_No_Reverb.{format}") if main else None), (os.path.join(output, f"Backing_Vocals_No_Reverb.{format}") if backing_reverb else None)
370
-
371
- def separator_main(audio_file=None, model_filename="UVR-MDX-NET_Main_340.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, demucs_segment_size=256, demucs_shifts=2, demucs_overlap=0.25):
372
- separator = Separator(
373
- log_formatter=file_formatter,
374
- log_level=logging.INFO,
375
- output_dir=output_dir,
376
- output_format=output_format,
377
- output_bitrate=None,
378
- normalization_threshold=0.9,
379
- output_single_stem=None,
380
- invert_using_spec=False,
381
- sample_rate=44100,
382
- mdx_params={
383
- "hop_length": mdx_hop_length,
384
- "segment_size": mdx_segment_size,
385
- "overlap": mdx_overlap,
386
- "batch_size": mdx_batch_size,
387
- "enable_denoise": mdx_enable_denoise,
388
- },
389
- demucs_params={
390
- "segment_size": demucs_segment_size,
391
- "shifts": demucs_shifts,
392
- "overlap": demucs_overlap,
393
- "segments_enabled": True,
394
- }
395
- )
396
-
397
- separator.load_model(model_filename=model_filename)
398
- return separator.separate(audio_file)
399
-
400
- if __name__ == "__main__": main()
 
main/inference/train.py DELETED
@@ -1,1600 +0,0 @@
1
- import os
2
- import re
3
- import sys
4
- import glob
5
- import json
6
- import torch
7
- import logging
8
- import hashlib
9
- import argparse
10
- import datetime
11
- import warnings
12
- import logging.handlers
13
-
14
- import numpy as np
15
- import torch.utils.data
16
- import matplotlib.pyplot as plt
17
- import torch.distributed as dist
18
- import torch.multiprocessing as mp
19
-
20
- from tqdm import tqdm
21
- from time import time as ttime
22
- from scipy.io.wavfile import read
23
- from collections import OrderedDict
24
- from random import randint, shuffle
25
-
26
- from torch.nn import functional as F
27
- from distutils.util import strtobool
28
- from torch.utils.data import DataLoader
29
- from torch.cuda.amp import GradScaler, autocast
30
- from torch.utils.tensorboard import SummaryWriter
31
- from librosa.filters import mel as librosa_mel_fn
32
- from torch.nn.parallel import DistributedDataParallel as DDP
33
- from torch.nn.utils.parametrizations import spectral_norm, weight_norm
34
-
35
- current_dir = os.getcwd()
36
- sys.path.append(current_dir)
37
-
38
- from main.configs.config import Config
39
- from main.library.algorithm.residuals import LRELU_SLOPE
40
- from main.library.algorithm.synthesizers import Synthesizer
41
- from main.library.algorithm.commons import get_padding, slice_segments, clip_grad_value
42
-
43
- warnings.filterwarnings("ignore", category=FutureWarning)
44
- warnings.filterwarnings("ignore", category=UserWarning)
45
-
46
- logging.getLogger("torch").setLevel(logging.ERROR)
47
-
48
- MATPLOTLIB_FLAG = False
49
-
50
- translations = Config().translations
51
-
52
- class HParams:
53
- def __init__(self, **kwargs):
54
- for k, v in kwargs.items():
55
- self[k] = HParams(**v) if isinstance(v, dict) else v
56
-
57
-
58
- def keys(self):
59
- return self.__dict__.keys()
60
-
61
-
62
- def items(self):
63
- return self.__dict__.items()
64
-
65
-
66
- def values(self):
67
- return self.__dict__.values()
68
-
69
-
70
- def __len__(self):
71
- return len(self.__dict__)
72
-
73
-
74
- def __getitem__(self, key):
75
- return self.__dict__[key]
76
-
77
-
78
- def __setitem__(self, key, value):
79
- self.__dict__[key] = value
80
-
81
-
82
- def __contains__(self, key):
83
- return key in self.__dict__
84
-
85
-
86
- def __repr__(self):
87
- return repr(self.__dict__)
88
-
89
- def parse_arguments() -> argparse.Namespace:
90
- parser = argparse.ArgumentParser()
91
- parser.add_argument("--model_name", type=str, required=True)
92
- parser.add_argument("--rvc_version", type=str, default="v2")
93
- parser.add_argument("--save_every_epoch", type=int, required=True)
94
- parser.add_argument("--save_only_latest", type=lambda x: bool(strtobool(x)), default=True)
95
- parser.add_argument("--save_every_weights", type=lambda x: bool(strtobool(x)), default=True)
96
- parser.add_argument("--total_epoch", type=int, default=300)
97
- parser.add_argument("--sample_rate", type=int, required=True)
98
- parser.add_argument("--batch_size", type=int, default=8)
99
- parser.add_argument("--gpu", type=str, default="0")
100
- parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
101
- parser.add_argument("--g_pretrained_path", type=str, default="")
102
- parser.add_argument("--d_pretrained_path", type=str, default="")
103
- parser.add_argument("--overtraining_detector", type=lambda x: bool(strtobool(x)), default=False)
104
- parser.add_argument("--overtraining_threshold", type=int, default=50)
105
- parser.add_argument("--sync_graph", type=lambda x: bool(strtobool(x)), default=False)
106
- parser.add_argument("--cache_data_in_gpu", type=lambda x: bool(strtobool(x)), default=False)
107
- parser.add_argument("--model_author", type=str)
108
-
109
- args = parser.parse_args()
110
- return args
111
-
112
-
113
- args = parse_arguments()
114
-
115
- model_name = args.model_name
116
- save_every_epoch = args.save_every_epoch
117
- total_epoch = args.total_epoch
118
- pretrainG = args.g_pretrained_path
119
- pretrainD = args.d_pretrained_path
120
- version = args.rvc_version
121
- gpus = args.gpu
122
- batch_size = args.batch_size
123
- sample_rate = args.sample_rate
124
- pitch_guidance = args.pitch_guidance
125
- save_only_latest = args.save_only_latest
126
- save_every_weights = args.save_every_weights
127
- cache_data_in_gpu = args.cache_data_in_gpu
128
- overtraining_detector = args.overtraining_detector
129
- overtraining_threshold = args.overtraining_threshold
130
- sync_graph = args.sync_graph
131
- model_author = args.model_author
132
-
133
- experiment_dir = os.path.join(current_dir, "assets", "logs", model_name)
134
- config_save_path = os.path.join(experiment_dir, "config.json")
135
-
136
- os.environ["CUDA_VISIBLE_DEVICES"] = gpus.replace("-", ",")
137
- n_gpus = len(gpus.split("-"))
138
-
139
- torch.backends.cudnn.deterministic = False
140
- torch.backends.cudnn.benchmark = False
141
-
142
- global_step = 0
143
- last_loss_gen_all = 0
144
- overtrain_save_epoch = 0
145
-
146
- loss_gen_history = []
147
- smoothed_loss_gen_history = []
148
- loss_disc_history = []
149
- smoothed_loss_disc_history = []
150
-
151
- lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
152
- training_file_path = os.path.join(experiment_dir, "training_data.json")
153
-
154
- with open(config_save_path, "r") as f:
155
- config = json.load(f)
156
-
157
- config = HParams(**config)
158
- config.data.training_files = os.path.join(experiment_dir, "filelist.txt")
159
-
160
- log_file = os.path.join(experiment_dir, "train.log")
161
-
162
- logger = logging.getLogger(__name__)
163
-
164
- if logger.hasHandlers(): logger.handlers.clear()
165
- else:
166
- console_handler = logging.StreamHandler()
167
- console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
168
-
169
- console_handler.setFormatter(console_formatter)
170
- console_handler.setLevel(logging.INFO)
171
-
172
- file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
173
- file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
174
-
175
- file_handler.setFormatter(file_formatter)
176
- file_handler.setLevel(logging.DEBUG)
177
-
178
- logger.addHandler(console_handler)
179
- logger.addHandler(file_handler)
180
- logger.setLevel(logging.DEBUG)
181
-
182
- logger.debug(f"{translations['modelname']}: {model_name}")
183
- logger.debug(translations["save_every_epoch"].format(save_every_epoch=save_every_epoch))
184
- logger.debug(translations["total_e"].format(total_epoch=total_epoch))
185
- logger.debug(translations["dorg"].format(pretrainG=pretrainG, pretrainD=pretrainD))
186
- logger.debug(f"{translations['training_version']}: {version}")
187
- logger.debug(f"Gpu: {gpus}")
188
- logger.debug(f"{translations['batch_size']}: {batch_size}")
189
- logger.debug(f"{translations['pretrain_sr']}: {sample_rate}")
190
- logger.debug(f"{translations['training_f0']}: {pitch_guidance}")
191
- logger.debug(f"{translations['save_only_latest']}: {save_only_latest}")
192
- logger.debug(f"{translations['save_every_weights']}: {save_every_weights}")
193
- logger.debug(f"{translations['cache_in_gpu']}: {cache_data_in_gpu}")
194
- logger.debug(f"{translations['overtraining_detector']}: {overtraining_detector}")
195
- logger.debug(f"{translations['threshold']}: {overtraining_threshold}")
196
- logger.debug(f"{translations['sync_graph']}: {sync_graph}")
197
- if model_author: logger.debug(translations["model_author"].format(model_author=model_author))
198
-
199
- def main():
200
- global training_file_path, last_loss_gen_all, smoothed_loss_gen_history, loss_gen_history, loss_disc_history, smoothed_loss_disc_history, overtrain_save_epoch, model_author
201
- os.environ["MASTER_ADDR"] = "localhost"
202
- os.environ["MASTER_PORT"] = str(randint(20000, 55555))
203
-
204
- if torch.cuda.is_available():
205
- device = torch.device("cuda")
206
- n_gpus = torch.cuda.device_count()
207
- elif torch.backends.mps.is_available():
208
- device = torch.device("mps")
209
- n_gpus = 1
210
- else:
211
- device = torch.device("cpu")
212
- n_gpus = 1
213
-
214
- def start():
215
- children = []
216
-
217
- for i in range(n_gpus):
218
- subproc = mp.Process(target=run, args=(i, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, custom_total_epoch, custom_save_every_weights, config, device, model_author))
219
- children.append(subproc)
220
- subproc.start()
221
-
222
- for i in range(n_gpus):
223
- children[i].join()
224
-
225
- def load_from_json(file_path):
226
- if os.path.exists(file_path):
227
- with open(file_path, "r") as f:
228
- data = json.load(f)
229
-
230
- return (
231
- data.get("loss_disc_history", []),
232
- data.get("smoothed_loss_disc_history", []),
233
- data.get("loss_gen_history", []),
234
- data.get("smoothed_loss_gen_history", []),
235
- )
236
-
237
- return [], [], [], []
238
-
239
- def continue_overtrain_detector(training_file_path):
240
- if overtraining_detector:
241
- if os.path.exists(training_file_path):
242
- (
243
- loss_disc_history,
244
- smoothed_loss_disc_history,
245
- loss_gen_history,
246
- smoothed_loss_gen_history,
247
- ) = load_from_json(training_file_path)
248
-
249
-
250
- n_gpus = torch.cuda.device_count()
251
-
252
- if not torch.cuda.is_available() and torch.backends.mps.is_available(): n_gpus = 1
253
-
254
- if n_gpus < 1:
255
- logger.warning(translations["not_gpu"])
256
- n_gpus = 1
257
-
258
- if sync_graph:
259
- logger.debug(translations["sync"])
260
- custom_total_epoch = 1
261
- custom_save_every_weights = True
262
-
263
- start()
264
-
265
- model_config_file = os.path.join(experiment_dir, "config.json")
266
- rvc_config_file = os.path.join(current_dir, "main", "configs", version, str(sample_rate) + ".json")
267
-
268
- if not os.path.exists(rvc_config_file): rvc_config_file = os.path.join(current_dir, "main", "configs", "v1", str(sample_rate) + ".json")
269
-
270
- pattern = rf"{os.path.basename(model_name)}_(\d+)e_(\d+)s\.pth"
271
-
272
- for filename in os.listdir(os.path.join("assets", "weights")):
273
- match = re.match(pattern, filename)
274
-
275
- if match: steps = int(match.group(2))
276
-
277
- def edit_config(config_file):
278
- with open(config_file, "r", encoding="utf8") as json_file:
279
- config_data = json.load(json_file)
280
-
281
- config_data["train"]["log_interval"] = steps
282
-
283
- with open(config_file, "w", encoding="utf8") as json_file:
284
- json.dump(config_data, json_file, indent=2, separators=(",", ": "), ensure_ascii=False)
285
-
286
- edit_config(model_config_file)
287
- edit_config(rvc_config_file)
288
-
289
- for root, dirs, files in os.walk(experiment_dir, topdown=False):
290
- for name in files:
291
- file_path = os.path.join(root, name)
292
- _, file_extension = os.path.splitext(name)
293
-
294
- if file_extension == ".0": os.remove(file_path)
295
- elif ("D" in name or "G" in name) and file_extension == ".pth": os.remove(file_path)
296
- elif ("added" in name or "trained" in name) and file_extension == ".index": os.remove(file_path)
297
-
298
- for name in dirs:
299
- if name == "eval":
300
- folder_path = os.path.join(root, name)
301
-
302
- for item in os.listdir(folder_path):
303
- item_path = os.path.join(folder_path, item)
304
- if os.path.isfile(item_path): os.remove(item_path)
305
-
306
- os.rmdir(folder_path)
307
-
308
- logger.info(translations["sync_success"])
309
- custom_total_epoch = total_epoch
310
- custom_save_every_weights = save_every_weights
311
-
312
- continue_overtrain_detector(training_file_path)
313
- start()
314
- else:
315
- custom_total_epoch = total_epoch
316
- custom_save_every_weights = save_every_weights
317
-
318
- continue_overtrain_detector(training_file_path)
319
- start()
320
-
321
-
322
- def plot_spectrogram_to_numpy(spectrogram):
323
- global MATPLOTLIB_FLAG
324
-
325
- if not MATPLOTLIB_FLAG:
326
- plt.switch_backend("Agg")
327
- MATPLOTLIB_FLAG = True
328
-
329
- fig, ax = plt.subplots(figsize=(10, 2))
330
- im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
331
-
332
- plt.colorbar(im, ax=ax)
333
- plt.xlabel("Frames")
334
- plt.ylabel("Channels")
335
- plt.tight_layout()
336
-
337
- fig.canvas.draw()
338
- data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
339
- data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
340
- plt.close(fig)
341
- return data
342
-
343
-
344
- def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sample_rate=22050):
345
- for k, v in scalars.items():
346
- writer.add_scalar(k, v, global_step)
347
-
348
- for k, v in histograms.items():
349
- writer.add_histogram(k, v, global_step)
350
-
351
- for k, v in images.items():
352
- writer.add_image(k, v, global_step, dataformats="HWC")
353
-
354
- for k, v in audios.items():
355
- writer.add_audio(k, v, global_step, audio_sample_rate)
356
-
357
-
358
- def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
359
- assert os.path.isfile(checkpoint_path), translations["not_found_checkpoint"].format(checkpoint_path=checkpoint_path)
360
-
361
- checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
362
- checkpoint_dict = replace_keys_in_dict(replace_keys_in_dict(checkpoint_dict, ".weight_v", ".parametrizations.weight.original1"), ".weight_g", ".parametrizations.weight.original0")
363
-
364
- model_state_dict = (model.module.state_dict() if hasattr(model, "module") else model.state_dict())
365
- new_state_dict = {k: checkpoint_dict["model"].get(k, v) for k, v in model_state_dict.items()}
366
-
367
- if hasattr(model, "module"): model.module.load_state_dict(new_state_dict, strict=False)
368
- else: model.load_state_dict(new_state_dict, strict=False)
369
-
370
- if optimizer and load_opt == 1: optimizer.load_state_dict(checkpoint_dict.get("optimizer", {}))
371
-
372
- logger.debug(translations["save_checkpoint"].format(checkpoint_path=checkpoint_path, checkpoint_dict=checkpoint_dict['iteration']))
373
-
374
- return (
375
- model,
376
- optimizer,
377
- checkpoint_dict.get("learning_rate", 0),
378
- checkpoint_dict["iteration"],
379
- )
380
-
381
-
382
- def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
383
- state_dict = (model.module.state_dict() if hasattr(model, "module") else model.state_dict())
384
-
385
- checkpoint_data = {
386
- "model": state_dict,
387
- "iteration": iteration,
388
- "optimizer": optimizer.state_dict(),
389
- "learning_rate": learning_rate,
390
- }
391
-
392
- torch.save(checkpoint_data, checkpoint_path)
393
-
394
- old_version_path = checkpoint_path.replace(".pth", "_old_version.pth")
395
- checkpoint_data = replace_keys_in_dict(replace_keys_in_dict(checkpoint_data, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g")
396
-
397
- torch.save(checkpoint_data, old_version_path)
398
-
399
- os.replace(old_version_path, checkpoint_path)
400
- logger.info(translations["save_model"].format(checkpoint_path=checkpoint_path, iteration=iteration))
401
-
402
-
403
- def latest_checkpoint_path(dir_path, regex="G_*.pth"):
404
- checkpoints = sorted(glob.glob(os.path.join(dir_path, regex)), key=lambda f: int("".join(filter(str.isdigit, f))))
405
-
406
- return checkpoints[-1] if checkpoints else None
407
-
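latest_checkpoint_path resumes from the highest-numbered G_*.pth / D_*.pth file by sorting on the integer formed from every digit in the path. A minimal sketch of that key (hypothetical file names, not part of the deleted script):

    step_key = lambda f: int("".join(filter(str.isdigit, f)))
    sorted(["G_99.pth", "G_100.pth", "G_2333333.pth"], key=step_key)[-1]  # 'G_2333333.pth'
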
408
-
409
- def load_wav_to_torch(full_path):
410
- sample_rate, data = read(full_path)
411
-
412
- return torch.FloatTensor(data.astype(np.float32)), sample_rate
413
-
414
-
415
- def load_filepaths_and_text(filename, split="|"):
416
- with open(filename, encoding="utf-8") as f:
417
- return [line.strip().split(split) for line in f]
418
-
419
-
420
- def feature_loss(fmap_r, fmap_g):
421
- loss = 0
422
-
423
- for dr, dg in zip(fmap_r, fmap_g):
424
- for rl, gl in zip(dr, dg):
425
- rl = rl.float().detach()
426
- gl = gl.float()
427
- loss += torch.mean(torch.abs(rl - gl))
428
-
429
- return loss * 2
430
-
431
-
432
- def discriminator_loss(disc_real_outputs, disc_generated_outputs):
433
- loss = 0
434
- r_losses = []
435
- g_losses = []
436
-
437
- for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
438
- dr = dr.float()
439
- dg = dg.float()
440
- r_loss = torch.mean((1 - dr) ** 2)
441
- g_loss = torch.mean(dg**2)
442
- loss += r_loss + g_loss
443
- r_losses.append(r_loss.item())
444
- g_losses.append(g_loss.item())
445
-
446
- return loss, r_losses, g_losses
447
-
448
-
449
- def generator_loss(disc_outputs):
450
- loss = 0
451
- gen_losses = []
452
-
453
- for dg in disc_outputs:
454
- dg = dg.float()
455
- l = torch.mean((1 - dg) ** 2)
456
- gen_losses.append(l)
457
- loss += l
458
-
459
- return loss, gen_losses
460
-
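feature_loss, discriminator_loss and generator_loss above implement the least-squares GAN objective plus an L1 feature-matching term: the discriminator pushes real outputs toward 1 and generated outputs toward 0, while the generator pushes its outputs toward 1. A tiny numeric sanity check (illustrative values only, not part of the deleted file):

    import torch

    d_real = [torch.tensor([0.9])]                       # discriminator score on real audio
    d_fake = [torch.tensor([0.1])]                       # discriminator score on generated audio
    loss_d, _, _ = discriminator_loss(d_real, d_fake)    # (1 - 0.9)**2 + 0.1**2 = 0.02
    loss_g, _ = generator_loss(d_fake)                   # (1 - 0.1)**2 = 0.81
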
461
-
462
- def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
463
- z_p = z_p.float()
464
- logs_q = logs_q.float()
465
- m_p = m_p.float()
466
- logs_p = logs_p.float()
467
- z_mask = z_mask.float()
468
-
469
- kl = logs_p - logs_q - 0.5
470
- kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
471
- kl = torch.sum(kl * z_mask)
472
- l = kl / torch.sum(z_mask)
473
-
474
- return l
475
-
476
-
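kl_loss is the masked, per-element Monte-Carlo estimate of KL(posterior || prior) used in VITS-style training, evaluated at the flow-mapped sample z_p. When prior and posterior coincide the estimate averages to zero, which gives a quick sanity check (a rough sketch, not part of the deleted file):

    import torch

    m = torch.zeros(1, 4, 1000)
    logs = torch.zeros(1, 4, 1000)
    z = m + torch.randn_like(m) * torch.exp(logs)    # sample from the "posterior"
    mask = torch.ones(1, 1, 1000)
    print(kl_loss(z, logs, m, logs, mask))           # ~0, up to sampling noise
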
477
- class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
478
- def __init__(self, hparams):
479
- self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
480
- self.max_wav_value = hparams.max_wav_value
481
- self.sample_rate = hparams.sample_rate
482
- self.filter_length = hparams.filter_length
483
- self.hop_length = hparams.hop_length
484
- self.win_length = hparams.win_length
485
- self.sample_rate = hparams.sample_rate
486
- self.min_text_len = getattr(hparams, "min_text_len", 1)
487
- self.max_text_len = getattr(hparams, "max_text_len", 5000)
488
- self._filter()
489
-
490
- def _filter(self):
491
- audiopaths_and_text_new = []
492
- lengths = []
493
-
494
- for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
495
- if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
496
- audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
497
- lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
498
-
499
- self.audiopaths_and_text = audiopaths_and_text_new
500
- self.lengths = lengths
501
-
502
- def get_sid(self, sid):
503
- try:
504
- sid = torch.LongTensor([int(sid)])
505
- except ValueError as e:
506
- logger.error(translations["sid_error"].format(sid=sid, e=e))
507
- sid = torch.LongTensor([0])
508
-
509
- return sid
510
-
511
- def get_audio_text_pair(self, audiopath_and_text):
512
- file = audiopath_and_text[0]
513
- phone = audiopath_and_text[1]
514
- pitch = audiopath_and_text[2]
515
- pitchf = audiopath_and_text[3]
516
- dv = audiopath_and_text[4]
517
-
518
- phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf)
519
- spec, wav = self.get_audio(file)
520
- dv = self.get_sid(dv)
521
-
522
- len_phone = phone.size()[0]
523
- len_spec = spec.size()[-1]
524
- if len_phone != len_spec:
525
- len_min = min(len_phone, len_spec)
526
- len_wav = len_min * self.hop_length
527
-
528
- spec = spec[:, :len_min]
529
- wav = wav[:, :len_wav]
530
-
531
- phone = phone[:len_min, :]
532
- pitch = pitch[:len_min]
533
- pitchf = pitchf[:len_min]
534
-
535
- return (spec, wav, phone, pitch, pitchf, dv)
536
-
537
- def get_labels(self, phone, pitch, pitchf):
538
- phone = np.load(phone)
539
- phone = np.repeat(phone, 2, axis=0)
540
-
541
- pitch = np.load(pitch)
542
- pitchf = np.load(pitchf)
543
-
544
- n_num = min(phone.shape[0], 900)
545
- phone = phone[:n_num, :]
546
-
547
- pitch = pitch[:n_num]
548
- pitchf = pitchf[:n_num]
549
-
550
- phone = torch.FloatTensor(phone)
551
-
552
- pitch = torch.LongTensor(pitch)
553
- pitchf = torch.FloatTensor(pitchf)
554
-
555
- return phone, pitch, pitchf
556
-
557
- def get_audio(self, filename):
558
- audio, sample_rate = load_wav_to_torch(filename)
559
-
560
- if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
561
-
562
- audio_norm = audio
563
- audio_norm = audio_norm.unsqueeze(0)
564
- spec_filename = filename.replace(".wav", ".spec.pt")
565
-
566
- if os.path.exists(spec_filename):
567
- try:
568
- spec = torch.load(spec_filename)
569
- except Exception as e:
570
- logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
571
- spec = spectrogram_torch(
572
- audio_norm,
573
- self.filter_length,
574
- self.hop_length,
575
- self.win_length,
576
- center=False,
577
- )
578
- spec = torch.squeeze(spec, 0)
579
-
580
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
581
- else:
582
- spec = spectrogram_torch(
583
- audio_norm,
584
- self.filter_length,
585
- self.hop_length,
586
- self.win_length,
587
- center=False,
588
- )
589
- spec = torch.squeeze(spec, 0)
590
-
591
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
592
- return spec, audio_norm
593
-
594
- def __getitem__(self, index):
595
- return self.get_audio_text_pair(self.audiopaths_and_text[index])
596
-
597
- def __len__(self):
598
- return len(self.audiopaths_and_text)
599
-
600
-
601
- class TextAudioCollateMultiNSFsid:
602
- def __init__(self, return_ids=False):
603
- self.return_ids = return_ids
604
-
605
- def __call__(self, batch):
606
- _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
607
-
608
- max_spec_len = max([x[0].size(1) for x in batch])
609
- max_wave_len = max([x[1].size(1) for x in batch])
610
-
611
- spec_lengths = torch.LongTensor(len(batch))
612
- wave_lengths = torch.LongTensor(len(batch))
613
- spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
614
- wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
615
-
616
- spec_padded.zero_()
617
- wave_padded.zero_()
618
-
619
- max_phone_len = max([x[2].size(0) for x in batch])
620
-
621
- phone_lengths = torch.LongTensor(len(batch))
622
- phone_padded = torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
623
- pitch_padded = torch.LongTensor(len(batch), max_phone_len)
624
- pitchf_padded = torch.FloatTensor(len(batch), max_phone_len)
625
-
626
- phone_padded.zero_()
627
- pitch_padded.zero_()
628
- pitchf_padded.zero_()
629
- sid = torch.LongTensor(len(batch))
630
-
631
- for i in range(len(ids_sorted_decreasing)):
632
- row = batch[ids_sorted_decreasing[i]]
633
-
634
- spec = row[0]
635
- spec_padded[i, :, : spec.size(1)] = spec
636
- spec_lengths[i] = spec.size(1)
637
-
638
- wave = row[1]
639
- wave_padded[i, :, : wave.size(1)] = wave
640
- wave_lengths[i] = wave.size(1)
641
-
642
- phone = row[2]
643
- phone_padded[i, : phone.size(0), :] = phone
644
- phone_lengths[i] = phone.size(0)
645
-
646
- pitch = row[3]
647
- pitch_padded[i, : pitch.size(0)] = pitch
648
- pitchf = row[4]
649
- pitchf_padded[i, : pitchf.size(0)] = pitchf
650
-
651
- sid[i] = row[5]
652
-
653
- return (
654
- phone_padded,
655
- phone_lengths,
656
- pitch_padded,
657
- pitchf_padded,
658
- spec_padded,
659
- spec_lengths,
660
- wave_padded,
661
- wave_lengths,
662
- sid,
663
- )
664
-
665
-
666
- class TextAudioLoader(torch.utils.data.Dataset):
667
- def __init__(self, hparams):
668
- self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
669
- self.max_wav_value = hparams.max_wav_value
670
- self.sample_rate = hparams.sample_rate
671
- self.filter_length = hparams.filter_length
672
- self.hop_length = hparams.hop_length
673
- self.win_length = hparams.win_length
674
- self.sample_rate = hparams.sample_rate
675
- self.min_text_len = getattr(hparams, "min_text_len", 1)
676
- self.max_text_len = getattr(hparams, "max_text_len", 5000)
677
- self._filter()
678
-
679
- def _filter(self):
680
- audiopaths_and_text_new = []
681
- lengths = []
682
-
683
- for entry in self.audiopaths_and_text:
684
- if len(entry) >= 3:
685
- audiopath, text, dv = entry[:3]
686
-
687
- if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
688
- audiopaths_and_text_new.append([audiopath, text, dv])
689
- lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
690
-
691
- self.audiopaths_and_text = audiopaths_and_text_new
692
- self.lengths = lengths
693
-
694
- def get_sid(self, sid):
695
- try:
696
- sid = torch.LongTensor([int(sid)])
697
- except ValueError as e:
698
- logger.error(translations["sid_error"].format(sid=sid, e=e))
699
- sid = torch.LongTensor([0])
700
-
701
- return sid
702
-
703
- def get_audio_text_pair(self, audiopath_and_text):
704
- file = audiopath_and_text[0]
705
- phone = audiopath_and_text[1]
706
- dv = audiopath_and_text[2]
707
-
708
- phone = self.get_labels(phone)
709
- spec, wav = self.get_audio(file)
710
- dv = self.get_sid(dv)
711
-
712
- len_phone = phone.size()[0]
713
- len_spec = spec.size()[-1]
714
-
715
- if len_phone != len_spec:
716
- len_min = min(len_phone, len_spec)
717
- len_wav = len_min * self.hop_length
718
- spec = spec[:, :len_min]
719
- wav = wav[:, :len_wav]
720
- phone = phone[:len_min, :]
721
-
722
- return (spec, wav, phone, dv)
723
-
724
- def get_labels(self, phone):
725
- phone = np.load(phone)
726
- phone = np.repeat(phone, 2, axis=0)
727
- n_num = min(phone.shape[0], 900)
728
- phone = phone[:n_num, :]
729
- phone = torch.FloatTensor(phone)
730
- return phone
731
-
732
- def get_audio(self, filename):
733
- audio, sample_rate = load_wav_to_torch(filename)
734
-
735
- if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
736
-
737
- audio_norm = audio
738
- audio_norm = audio_norm.unsqueeze(0)
739
-
740
- spec_filename = filename.replace(".wav", ".spec.pt")
741
-
742
- if os.path.exists(spec_filename):
743
- try:
744
- spec = torch.load(spec_filename)
745
- except Exception as e:
746
- logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
747
- spec = spectrogram_torch(
748
- audio_norm,
749
- self.filter_length,
750
- self.hop_length,
751
- self.win_length,
752
- center=False,
753
- )
754
- spec = torch.squeeze(spec, 0)
755
-
756
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
757
- else:
758
- spec = spectrogram_torch(
759
- audio_norm,
760
- self.filter_length,
761
- self.hop_length,
762
- self.win_length,
763
- center=False,
764
- )
765
- spec = torch.squeeze(spec, 0)
766
-
767
- torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
768
- return spec, audio_norm
769
-
770
- def __getitem__(self, index):
771
- return self.get_audio_text_pair(self.audiopaths_and_text[index])
772
-
773
- def __len__(self):
774
- return len(self.audiopaths_and_text)
775
-
776
-
777
- class TextAudioCollate:
778
- def __init__(self, return_ids=False):
779
- self.return_ids = return_ids
780
-
781
- def __call__(self, batch):
782
- _, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
783
-
784
- max_spec_len = max([x[0].size(1) for x in batch])
785
- max_wave_len = max([x[1].size(1) for x in batch])
786
-
787
- spec_lengths = torch.LongTensor(len(batch))
788
- wave_lengths = torch.LongTensor(len(batch))
789
-
790
- spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
791
- wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
792
-
793
- spec_padded.zero_()
794
- wave_padded.zero_()
795
-
796
- max_phone_len = max([x[2].size(0) for x in batch])
797
-
798
- phone_lengths = torch.LongTensor(len(batch))
799
- phone_padded = torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
800
-
801
- phone_padded.zero_()
802
- sid = torch.LongTensor(len(batch))
803
-
804
- for i in range(len(ids_sorted_decreasing)):
805
- row = batch[ids_sorted_decreasing[i]]
806
-
807
- spec = row[0]
808
- spec_padded[i, :, : spec.size(1)] = spec
809
- spec_lengths[i] = spec.size(1)
810
-
811
- wave = row[1]
812
- wave_padded[i, :, : wave.size(1)] = wave
813
- wave_lengths[i] = wave.size(1)
814
-
815
- phone = row[2]
816
- phone_padded[i, : phone.size(0), :] = phone
817
- phone_lengths[i] = phone.size(0)
818
-
819
- sid[i] = row[3]
820
-
821
- return (
822
- phone_padded,
823
- phone_lengths,
824
- spec_padded,
825
- spec_lengths,
826
- wave_padded,
827
- wave_lengths,
828
- sid,
829
- )
830
-
831
-
832
- class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
833
- def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
834
- super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
835
- self.lengths = dataset.lengths
836
- self.batch_size = batch_size
837
- self.boundaries = boundaries
838
- self.buckets, self.num_samples_per_bucket = self._create_buckets()
839
- self.total_size = sum(self.num_samples_per_bucket)
840
- self.num_samples = self.total_size // self.num_replicas
841
-
842
- def _create_buckets(self):
843
- buckets = [[] for _ in range(len(self.boundaries) - 1)]
844
-
845
- for i in range(len(self.lengths)):
846
- length = self.lengths[i]
847
- idx_bucket = self._bisect(length)
848
- if idx_bucket != -1: buckets[idx_bucket].append(i)
849
-
850
- for i in range(len(buckets) - 1, -1, -1):
851
- if len(buckets[i]) == 0:
852
- buckets.pop(i)
853
- self.boundaries.pop(i + 1)
854
-
855
- num_samples_per_bucket = []
856
-
857
- for i in range(len(buckets)):
858
- len_bucket = len(buckets[i])
859
- total_batch_size = self.num_replicas * self.batch_size
860
- rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
861
- num_samples_per_bucket.append(len_bucket + rem)
862
-
863
- return buckets, num_samples_per_bucket
864
-
865
- def __iter__(self):
866
- g = torch.Generator()
867
- g.manual_seed(self.epoch)
868
-
869
- indices = []
870
-
871
- if self.shuffle:
872
- for bucket in self.buckets:
873
- indices.append(torch.randperm(len(bucket), generator=g).tolist())
874
- else:
875
- for bucket in self.buckets:
876
- indices.append(list(range(len(bucket))))
877
-
878
- batches = []
879
-
880
- for i in range(len(self.buckets)):
881
- bucket = self.buckets[i]
882
- len_bucket = len(bucket)
883
- ids_bucket = indices[i]
884
- num_samples_bucket = self.num_samples_per_bucket[i]
885
-
886
- rem = num_samples_bucket - len_bucket
887
-
888
- ids_bucket = (ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)])
889
- ids_bucket = ids_bucket[self.rank :: self.num_replicas]
890
-
891
- for j in range(len(ids_bucket) // self.batch_size):
892
- batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
893
- batches.append(batch)
894
-
895
- if self.shuffle:
896
- batch_ids = torch.randperm(len(batches), generator=g).tolist()
897
- batches = [batches[i] for i in batch_ids]
898
-
899
- self.batches = batches
900
-
901
- assert len(self.batches) * self.batch_size == self.num_samples
902
- return iter(self.batches)
903
-
904
- def _bisect(self, x, lo=0, hi=None):
905
- if hi is None: hi = len(self.boundaries) - 1
906
-
907
- if hi > lo:
908
- mid = (hi + lo) // 2
909
-
910
- if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: return mid
911
- elif x <= self.boundaries[mid]: return self._bisect(x, lo, mid)
912
- else: return self._bisect(x, mid + 1, hi)
913
-
914
- else: return -1
915
-
916
- def __len__(self):
917
- return self.num_samples // self.batch_size
918
-
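DistributedBucketSampler groups clips of similar length (the boundaries are in the same units as dataset.lengths) so each rank sees full, low-padding batches; it replaces both the shuffler and the batch sampler. A rough usage sketch mirroring the call made in run() further below (argument values are the ones used there):

    sampler = DistributedBucketSampler(
        train_dataset,                                  # must expose .lengths
        batch_size * n_gpus,
        [100, 200, 300, 400, 500, 600, 700, 800, 900],
        num_replicas=n_gpus, rank=rank, shuffle=True,
    )
    loader = DataLoader(train_dataset, batch_sampler=sampler, collate_fn=TextAudioCollateMultiNSFsid())
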
919
-
920
- class MultiPeriodDiscriminator(torch.nn.Module):
921
- def __init__(self, use_spectral_norm=False):
922
- super(MultiPeriodDiscriminator, self).__init__()
923
- periods = [2, 3, 5, 7, 11, 17]
924
- self.discriminators = torch.nn.ModuleList([DiscriminatorS(use_spectral_norm=use_spectral_norm)] + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods])
925
-
926
- def forward(self, y, y_hat):
927
- y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
928
-
929
- for d in self.discriminators:
930
- y_d_r, fmap_r = d(y)
931
- y_d_g, fmap_g = d(y_hat)
932
- y_d_rs.append(y_d_r)
933
- y_d_gs.append(y_d_g)
934
- fmap_rs.append(fmap_r)
935
- fmap_gs.append(fmap_g)
936
-
937
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
938
-
939
-
940
- class MultiPeriodDiscriminatorV2(torch.nn.Module):
941
- def __init__(self, use_spectral_norm=False):
942
- super(MultiPeriodDiscriminatorV2, self).__init__()
943
- periods = [2, 3, 5, 7, 11, 17, 23, 37]
944
- self.discriminators = torch.nn.ModuleList([DiscriminatorS(use_spectral_norm=use_spectral_norm)] + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods])
945
-
946
- def forward(self, y, y_hat):
947
- y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
948
-
949
- for d in self.discriminators:
950
- y_d_r, fmap_r = d(y)
951
- y_d_g, fmap_g = d(y_hat)
952
- y_d_rs.append(y_d_r)
953
- y_d_gs.append(y_d_g)
954
- fmap_rs.append(fmap_r)
955
- fmap_gs.append(fmap_g)
956
-
957
- return y_d_rs, y_d_gs, fmap_rs, fmap_gs
958
-
959
-
960
- class DiscriminatorS(torch.nn.Module):
961
- def __init__(self, use_spectral_norm=False):
962
- super(DiscriminatorS, self).__init__()
963
- norm_f = spectral_norm if use_spectral_norm else weight_norm
964
-
965
- self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv1d(1, 16, 15, 1, padding=7)), norm_f(torch.nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), norm_f(torch.nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), norm_f(torch.nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 5, 1, padding=2))])
966
- self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1))
967
- self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
968
-
969
- def forward(self, x):
970
- fmap = []
971
-
972
- for conv in self.convs:
973
- x = self.lrelu(conv(x))
974
- fmap.append(x)
975
-
976
- x = self.conv_post(x)
977
- fmap.append(x)
978
- x = torch.flatten(x, 1, -1)
979
-
980
- return x, fmap
981
-
982
-
983
- class DiscriminatorP(torch.nn.Module):
984
- def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
985
- super(DiscriminatorP, self).__init__()
986
- self.period = period
987
- norm_f = spectral_norm if use_spectral_norm else weight_norm
988
-
989
- in_channels = [1, 32, 128, 512, 1024]
990
- out_channels = [32, 128, 512, 1024, 1024]
991
-
992
- self.convs = torch.nn.ModuleList(
993
- [
994
- norm_f(
995
- torch.nn.Conv2d(
996
- in_ch,
997
- out_ch,
998
- (kernel_size, 1),
999
- (stride, 1),
1000
- padding=(get_padding(kernel_size, 1), 0),
1001
- )
1002
- )
1003
- for in_ch, out_ch in zip(in_channels, out_channels)
1004
- ]
1005
- )
1006
-
1007
- self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
1008
- self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)
1009
-
1010
- def forward(self, x):
1011
- fmap = []
1012
- b, c, t = x.shape
1013
-
1014
- if t % self.period != 0:
1015
- n_pad = self.period - (t % self.period)
1016
- x = torch.nn.functional.pad(x, (0, n_pad), "reflect")
1017
-
1018
- x = x.view(b, c, -1, self.period)
1019
-
1020
- for conv in self.convs:
1021
- x = self.lrelu(conv(x))
1022
- fmap.append(x)
1023
-
1024
- x = self.conv_post(x)
1025
- fmap.append(x)
1026
- x = torch.flatten(x, 1, -1)
1027
- return x, fmap
1028
-
1029
-
1030
- class EpochRecorder:
1031
- def __init__(self):
1032
- self.last_time = ttime()
1033
-
1034
- def record(self):
1035
- now_time = ttime()
1036
- elapsed_time = now_time - self.last_time
1037
- self.last_time = now_time
1038
- elapsed_time = round(elapsed_time, 1)
1039
- elapsed_time_str = str(datetime.timedelta(seconds=int(elapsed_time)))
1040
- current_time = datetime.datetime.now().strftime("%H:%M:%S")
1041
- return translations["time_or_speed_training"].format(current_time=current_time, elapsed_time_str=elapsed_time_str)
1042
-
1043
-
1044
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
1045
- return torch.log(torch.clamp(x, min=clip_val) * C)
1046
-
1047
-
1048
- def dynamic_range_decompression_torch(x, C=1):
1049
- return torch.exp(x) / C
1050
-
1051
-
1052
- def spectral_normalize_torch(magnitudes):
1053
- return dynamic_range_compression_torch(magnitudes)
1054
-
1055
-
1056
- def spectral_de_normalize_torch(magnitudes):
1057
- return dynamic_range_decompression_torch(magnitudes)
1058
-
1059
-
1060
- mel_basis = {}
1061
- hann_window = {}
1062
-
1063
-
1064
- def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
1065
- global hann_window
1066
- dtype_device = str(y.dtype) + "_" + str(y.device)
1067
- wnsize_dtype_device = str(win_size) + "_" + dtype_device
1068
- if wnsize_dtype_device not in hann_window: hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)
1069
-
1070
- y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect")
1071
-
1072
- y = y.squeeze(1)
1073
-
1074
- spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
1075
-
1076
- spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
1077
- return spec
1078
-
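spectrogram_torch returns a linear-magnitude STFT of shape (batch, n_fft // 2 + 1, frames), caching one Hann window per dtype/device. A quick shape check with made-up settings (illustrative only):

    import torch

    y = torch.randn(1, 40000)                                         # ~1 s of audio at 40 kHz
    spec = spectrogram_torch(y, n_fft=2048, hop_size=400, win_size=2048)
    print(spec.shape)                                                 # torch.Size([1, 1025, 100])
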
1079
-
1080
- def spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax):
1081
- global mel_basis
1082
- dtype_device = str(spec.dtype) + "_" + str(spec.device)
1083
- fmax_dtype_device = str(fmax) + "_" + dtype_device
1084
-
1085
- if fmax_dtype_device not in mel_basis:
1086
- mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
1087
- mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)
1088
-
1089
- melspec = torch.matmul(mel_basis[fmax_dtype_device], spec)
1090
- melspec = spectral_normalize_torch(melspec)
1091
- return melspec
1092
-
1093
-
1094
- def mel_spectrogram_torch(y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False):
1095
- spec = spectrogram_torch(y, n_fft, hop_size, win_size, center)
1096
-
1097
- melspec = spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax)
1098
-
1099
- return melspec
1100
-
1101
-
1102
- def replace_keys_in_dict(d, old_key_part, new_key_part):
1103
- updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}
1104
-
1105
- for key, value in d.items():
1106
- new_key = (key.replace(old_key_part, new_key_part) if isinstance(key, str) else key)
1107
- updated_dict[new_key] = (replace_keys_in_dict(value, old_key_part, new_key_part) if isinstance(value, dict) else value)
1108
-
1109
- return updated_dict
1110
-
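replace_keys_in_dict recursively renames matching substrings in (nested) state-dict keys; it is what converts checkpoints between the legacy weight-norm names (.weight_g / .weight_v) and the newer parametrization names. A tiny example with hypothetical keys:

    old = {"model": {"conv.weight_g": 1, "conv.weight_v": 2}}
    new = replace_keys_in_dict(old, ".weight_v", ".parametrizations.weight.original1")
    new = replace_keys_in_dict(new, ".weight_g", ".parametrizations.weight.original0")
    # {'model': {'conv.parametrizations.weight.original0': 1, 'conv.parametrizations.weight.original1': 2}}
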
1111
-
1112
- def extract_model(ckpt, sr, pitch_guidance, name, model_dir, epoch, step, version, hps, model_author):
1113
- try:
1114
- logger.info(translations["savemodel"].format(model_dir=model_dir, epoch=epoch, step=step))
1115
-
1116
- model_dir_path = os.path.join("assets", "weights")
1117
-
1118
- if "best_epoch" in model_dir: pth_file = f"{name}_{epoch}e_{step}s_best_epoch.pth"
1119
- else: pth_file = f"{name}_{epoch}e_{step}s.pth"
1120
-
1121
- pth_file_old_version_path = os.path.join(model_dir_path, f"{pth_file}_old_version.pth")
1122
-
1123
- opt = OrderedDict(weight={key: value.half() for key, value in ckpt.items() if "enc_q" not in key})
1124
-
1125
- opt["config"] = [
1126
- hps.data.filter_length // 2 + 1,
1127
- 32,
1128
- hps.model.inter_channels,
1129
- hps.model.hidden_channels,
1130
- hps.model.filter_channels,
1131
- hps.model.n_heads,
1132
- hps.model.n_layers,
1133
- hps.model.kernel_size,
1134
- hps.model.p_dropout,
1135
- hps.model.resblock,
1136
- hps.model.resblock_kernel_sizes,
1137
- hps.model.resblock_dilation_sizes,
1138
- hps.model.upsample_rates,
1139
- hps.model.upsample_initial_channel,
1140
- hps.model.upsample_kernel_sizes,
1141
- hps.model.spk_embed_dim,
1142
- hps.model.gin_channels,
1143
- hps.data.sample_rate,
1144
- ]
1145
-
1146
- opt["epoch"] = f"{epoch}epoch"
1147
- opt["step"] = step
1148
- opt["sr"] = sr
1149
- opt["f0"] = int(pitch_guidance)
1150
- opt["version"] = version
1151
- opt["creation_date"] = datetime.datetime.now().isoformat()
1152
-
1153
- hash_input = f"{str(ckpt)} {epoch} {step} {datetime.datetime.now().isoformat()}"
1154
- model_hash = hashlib.sha256(hash_input.encode()).hexdigest()
1155
- opt["model_hash"] = model_hash
1156
- opt["model_name"] = name
1157
- opt["author"] = model_author
1158
-
1159
- torch.save(opt, os.path.join(model_dir_path, pth_file))
1160
-
1161
- model = torch.load(model_dir, map_location=torch.device("cpu"))
1162
- torch.save(replace_keys_in_dict(replace_keys_in_dict(model, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), pth_file_old_version_path)
1163
-
1164
- os.remove(model_dir)
1165
- os.rename(pth_file_old_version_path, model_dir)
1166
-
1167
- except Exception as e:
1168
- logger.error(f"{translations['extract_model_error']}: {e}")
1169
-
1170
-
1171
- def run(rank, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, custom_total_epoch, custom_save_every_weights, config, device, model_author):
1172
- global global_step
1173
-
1174
- if rank == 0:
1175
- writer = SummaryWriter(log_dir=experiment_dir)
1176
- writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))
1177
-
1178
- dist.init_process_group(backend="gloo", init_method="env://", world_size=n_gpus, rank=rank)
1179
- torch.manual_seed(config.train.seed)
1180
-
1181
- if torch.cuda.is_available(): torch.cuda.set_device(rank)
1182
-
1183
- train_dataset = TextAudioLoaderMultiNSFsid(config.data)
1184
-
1185
- train_sampler = DistributedBucketSampler(train_dataset, batch_size * n_gpus, [100, 200, 300, 400, 500, 600, 700, 800, 900], num_replicas=n_gpus, rank=rank, shuffle=True)
1186
-
1187
- collate_fn = TextAudioCollateMultiNSFsid()
1188
-
1189
- train_loader = DataLoader(train_dataset, num_workers=4, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler, persistent_workers=True, prefetch_factor=8)
1190
-
1191
- net_g = Synthesizer(config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, **config.model, use_f0=pitch_guidance == True, is_half=config.train.fp16_run and device.type == "cuda", sr=sample_rate).to(device)
1192
-
1193
- if torch.cuda.is_available(): net_g = net_g.cuda(rank)
1194
-
1195
- if version == "v1": net_d = MultiPeriodDiscriminator(config.model.use_spectral_norm)
1196
- else: net_d = MultiPeriodDiscriminatorV2(config.model.use_spectral_norm)
1197
-
1198
- if torch.cuda.is_available(): net_d = net_d.cuda(rank)
1199
-
1200
- optim_g = torch.optim.AdamW(net_g.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
1201
- optim_d = torch.optim.AdamW(net_d.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
1202
-
1203
- if torch.cuda.is_available():
1204
- net_g = DDP(net_g, device_ids=[rank])
1205
- net_d = DDP(net_d, device_ids=[rank])
1206
- else:
1207
- net_g = DDP(net_g)
1208
- net_d = DDP(net_d)
1209
-
1210
- try:
1211
- logger.info(translations["start_training"])
1212
-
1213
- _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "D_*.pth"), net_d, optim_d)
1214
- _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "G_*.pth"), net_g, optim_g)
1215
-
1216
- epoch_str += 1
1217
- global_step = (epoch_str - 1) * len(train_loader)
1218
-
1219
- except:
1220
- epoch_str = 1
1221
- global_step = 0
1222
-
1223
- if pretrainG != "":
1224
- if rank == 0: logger.info(translations["import_pretrain"].format(dg="G", pretrain=pretrainG))
1225
-
1226
- if hasattr(net_g, "module"): net_g.module.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
1227
- else: net_g.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
1228
- else: logger.warning(translations["not_using_pretrain"].format(dg="G"))
1229
-
1230
- if pretrainD != "":
1231
- if rank == 0: logger.info(translations["import_pretrain"].format(dg="D", pretrain=pretrainD))
1232
-
1233
- if hasattr(net_d, "module"): net_d.module.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
1234
- else: net_d.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
1235
- else: logger.warning(translations["not_using_pretrain"].format(dg="D"))
1236
-
1237
- scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
1238
- scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
1239
-
1240
- optim_d.step()
1241
- optim_g.step()
1242
-
1243
- scaler = GradScaler(enabled=config.train.fp16_run)
1244
-
1245
- cache = []
1246
-
1247
- for info in train_loader:
1248
- phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
1249
- reference = (
1250
- phone.to(device),
1251
- phone_lengths.to(device),
1252
- pitch.to(device) if pitch_guidance else None,
1253
- pitchf.to(device) if pitch_guidance else None,
1254
- sid.to(device),
1255
- )
1256
- break
1257
-
1258
- for epoch in range(epoch_str, total_epoch + 1):
1259
- if rank == 0: train_and_evaluate(rank, epoch, config, [net_g, net_d], [optim_g, optim_d], scaler, [train_loader, None], [writer, writer_eval], cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author)
1260
- else: train_and_evaluate(rank, epoch, config, [net_g, net_d], [optim_g, optim_d], scaler, [train_loader, None], None, cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author)
1261
-
1262
- scheduler_g.step()
1263
- scheduler_d.step()
1264
-
1265
-
1266
- def train_and_evaluate(rank, epoch, hps, nets, optims, scaler, loaders, writers, cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author):
1267
- global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc
1268
-
1269
- if epoch == 1:
1270
- lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
1271
- last_loss_gen_all = 0.0
1272
- consecutive_increases_gen = 0
1273
- consecutive_increases_disc = 0
1274
-
1275
- net_g, net_d = nets
1276
- optim_g, optim_d = optims
1277
- train_loader = loaders[0] if loaders is not None else None
1278
-
1279
- if writers is not None: writer = writers[0]
1280
-
1281
- train_loader.batch_sampler.set_epoch(epoch)
1282
-
1283
- net_g.train()
1284
- net_d.train()
1285
-
1286
- if device.type == "cuda" and cache_data_in_gpu:
1287
- data_iterator = cache
1288
-
1289
- if cache == []:
1290
- for batch_idx, info in enumerate(train_loader):
1291
- (
1292
- phone,
1293
- phone_lengths,
1294
- pitch,
1295
- pitchf,
1296
- spec,
1297
- spec_lengths,
1298
- wave,
1299
- wave_lengths,
1300
- sid,
1301
- ) = info
1302
- cache.append(
1303
- (batch_idx, (
1304
- phone.cuda(rank, non_blocking=True),
1305
- phone_lengths.cuda(rank, non_blocking=True),
1306
- (pitch.cuda(rank, non_blocking=True) if pitch_guidance else None),
1307
- (pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None),
1308
- spec.cuda(rank, non_blocking=True),
1309
- spec_lengths.cuda(rank, non_blocking=True),
1310
- wave.cuda(rank, non_blocking=True),
1311
- wave_lengths.cuda(rank, non_blocking=True),
1312
- sid.cuda(rank, non_blocking=True),
1313
- ),
1314
- ))
1315
- else: shuffle(cache)
1316
- else: data_iterator = enumerate(train_loader)
1317
-
1318
- epoch_recorder = EpochRecorder()
1319
-
1320
- with tqdm(total=len(train_loader), leave=False) as pbar:
1321
- for batch_idx, info in data_iterator:
1322
- (
1323
- phone,
1324
- phone_lengths,
1325
- pitch,
1326
- pitchf,
1327
- spec,
1328
- spec_lengths,
1329
- wave,
1330
- wave_lengths,
1331
- sid,
1332
- ) = info
1333
- if device.type == "cuda" and not cache_data_in_gpu:
1334
- phone = phone.cuda(rank, non_blocking=True)
1335
- phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
1336
- pitch = pitch.cuda(rank, non_blocking=True) if pitch_guidance else None
1337
- pitchf = (pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None)
1338
- sid = sid.cuda(rank, non_blocking=True)
1339
- spec = spec.cuda(rank, non_blocking=True)
1340
- spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
1341
- wave = wave.cuda(rank, non_blocking=True)
1342
- wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
1343
- else:
1344
- phone = phone.to(device)
1345
- phone_lengths = phone_lengths.to(device)
1346
- pitch = pitch.to(device) if pitch_guidance else None
1347
- pitchf = pitchf.to(device) if pitch_guidance else None
1348
- sid = sid.to(device)
1349
- spec = spec.to(device)
1350
- spec_lengths = spec_lengths.to(device)
1351
- wave = wave.to(device)
1352
- wave_lengths = wave_lengths.to(device)
1353
-
1354
- use_amp = config.train.fp16_run and device.type == "cuda"
1355
-
1356
- with autocast(enabled=use_amp):
1357
- (
1358
- y_hat,
1359
- ids_slice,
1360
- x_mask,
1361
- z_mask,
1362
- (z, z_p, m_p, logs_p, m_q, logs_q),
1363
- ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
1364
- mel = spec_to_mel_torch(
1365
- spec,
1366
- config.data.filter_length,
1367
- config.data.n_mel_channels,
1368
- config.data.sample_rate,
1369
- config.data.mel_fmin,
1370
- config.data.mel_fmax,
1371
- )
1372
- y_mel = slice_segments(mel, ids_slice, config.train.segment_size // config.data.hop_length, dim=3)
1373
- with autocast(enabled=False):
1374
- y_hat_mel = mel_spectrogram_torch(
1375
- y_hat.float().squeeze(1),
1376
- config.data.filter_length,
1377
- config.data.n_mel_channels,
1378
- config.data.sample_rate,
1379
- config.data.hop_length,
1380
- config.data.win_length,
1381
- config.data.mel_fmin,
1382
- config.data.mel_fmax,
1383
- )
1384
- if use_amp: y_hat_mel = y_hat_mel.half()
1385
-
1386
- wave = slice_segments(wave, ids_slice * config.data.hop_length, config.train.segment_size, dim=3)
1387
- y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())
1388
-
1389
- with autocast(enabled=False):
1390
- loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)
1391
-
1392
- optim_d.zero_grad()
1393
- scaler.scale(loss_disc).backward()
1394
- scaler.unscale_(optim_d)
1395
- grad_norm_d = clip_grad_value(net_d.parameters(), None)
1396
- scaler.step(optim_d)
1397
-
1398
- with autocast(enabled=use_amp):
1399
- y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
1400
- with autocast(enabled=False):
1401
- loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
1402
- loss_kl = (kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl)
1403
- loss_fm = feature_loss(fmap_r, fmap_g)
1404
- loss_gen, losses_gen = generator_loss(y_d_hat_g)
1405
- loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl
1406
-
1407
- if loss_gen_all < lowest_value["value"]:
1408
- lowest_value["value"] = loss_gen_all
1409
- lowest_value["step"] = global_step
1410
- lowest_value["epoch"] = epoch
1411
-
1412
- if epoch > lowest_value["epoch"]: logger.warning(translations["training_warning"])
1413
-
1414
- optim_g.zero_grad()
1415
- scaler.scale(loss_gen_all).backward()
1416
- scaler.unscale_(optim_g)
1417
- grad_norm_g = clip_grad_value(net_g.parameters(), None)
1418
- scaler.step(optim_g)
1419
- scaler.update()
1420
-
1421
- if rank == 0:
1422
- if global_step % config.train.log_interval == 0:
1423
- lr = optim_g.param_groups[0]["lr"]
1424
-
1425
- if loss_mel > 75: loss_mel = 75
1426
- if loss_kl > 9: loss_kl = 9
1427
-
1428
- scalar_dict = {
1429
- "loss/g/total": loss_gen_all,
1430
- "loss/d/total": loss_disc,
1431
- "learning_rate": lr,
1432
- "grad_norm_d": grad_norm_d,
1433
- "grad_norm_g": grad_norm_g,
1434
- }
1435
- scalar_dict.update(
1436
- {
1437
- "loss/g/fm": loss_fm,
1438
- "loss/g/mel": loss_mel,
1439
- "loss/g/kl": loss_kl,
1440
- }
1441
- )
1442
- scalar_dict.update(
1443
- {f"loss/g/{i}": v for i, v in enumerate(losses_gen)}
1444
- )
1445
- scalar_dict.update(
1446
- {f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)}
1447
- )
1448
- scalar_dict.update(
1449
- {f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)}
1450
- )
1451
- image_dict = {
1452
- "slice/mel_org": plot_spectrogram_to_numpy(
1453
- y_mel[0].data.cpu().numpy()
1454
- ),
1455
- "slice/mel_gen": plot_spectrogram_to_numpy(
1456
- y_hat_mel[0].data.cpu().numpy()
1457
- ),
1458
- "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
1459
- }
1460
-
1461
- with torch.no_grad():
1462
- if hasattr(net_g, "module"): o, *_ = net_g.module.infer(*reference)
1463
- else: o, *_ = net_g.infer(*reference)
1464
-
1465
-
1466
- audio_dict = {f"gen/audio_{global_step:07d}": o[0, :, :]}
1467
-
1468
- summarize(
1469
- writer=writer,
1470
- global_step=global_step,
1471
- images=image_dict,
1472
- scalars=scalar_dict,
1473
- audios=audio_dict,
1474
- audio_sample_rate=config.data.sample_rate,
1475
- )
1476
-
1477
- global_step += 1
1478
- pbar.update(1)
1479
-
1480
- def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004):
1481
- if len(smoothed_loss_history) < threshold + 1: return False
1482
-
1483
- for i in range(-threshold, -1):
1484
- if smoothed_loss_history[i + 1] > smoothed_loss_history[i]: return True
1485
- if abs(smoothed_loss_history[i + 1] - smoothed_loss_history[i]) >= epsilon: return False
1486
-
1487
- return True
1488
-
1489
- def update_exponential_moving_average(smoothed_loss_history, new_value, smoothing=0.987):
1490
- smoothed_value = new_value if not smoothed_loss_history else (smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value)
1491
- smoothed_loss_history.append(smoothed_value)
1492
- return smoothed_value
1493
-
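update_exponential_moving_average smooths the per-epoch losses with factor 0.987 (roughly a 77-epoch memory), and check_overtraining only fires once the smoothed curve keeps rising, so a single noisy epoch barely moves it. An illustration with made-up losses:

    history = []
    for loss in (2.0, 1.8, 1.7, 1.9):                  # hypothetical generator losses per epoch
        update_exponential_moving_average(history, loss)
    print([round(v, 3) for v in history])              # [2.0, 1.997, 1.994, 1.992]
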
1494
- def save_to_json(file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history):
1495
- data = {
1496
- "loss_disc_history": loss_disc_history,
1497
- "smoothed_loss_disc_history": smoothed_loss_disc_history,
1498
- "loss_gen_history": loss_gen_history,
1499
- "smoothed_loss_gen_history": smoothed_loss_gen_history,
1500
- }
1501
-
1502
- with open(file_path, "w") as f:
1503
- json.dump(data, f)
1504
-
1505
- model_add = []
1506
- model_del = []
1507
- done = False
1508
-
1509
- if rank == 0:
1510
- if epoch % save_every_epoch == 0:
1511
- checkpoint_suffix = f"{2333333 if save_only_latest else global_step}.pth"
1512
-
1513
- save_checkpoint(net_g, optim_g, config.train.learning_rate, epoch, os.path.join(experiment_dir, "G_" + checkpoint_suffix))
1514
- save_checkpoint(net_d, optim_d, config.train.learning_rate, epoch, os.path.join(experiment_dir, "D_" + checkpoint_suffix))
1515
-
1516
- if custom_save_every_weights: model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
1517
-
1518
- if overtraining_detector and epoch > 1:
1519
- current_loss_disc = float(loss_disc)
1520
- loss_disc_history.append(current_loss_disc)
1521
-
1522
- smoothed_value_disc = update_exponential_moving_average(smoothed_loss_disc_history, current_loss_disc)
1523
- is_overtraining_disc = check_overtraining(smoothed_loss_disc_history, overtraining_threshold * 2)
1524
-
1525
- if is_overtraining_disc: consecutive_increases_disc += 1
1526
- else: consecutive_increases_disc = 0
1527
-
1528
- current_loss_gen = float(lowest_value["value"])
1529
- loss_gen_history.append(current_loss_gen)
1530
-
1531
- smoothed_value_gen = update_exponential_moving_average(smoothed_loss_gen_history, current_loss_gen)
1532
- is_overtraining_gen = check_overtraining(smoothed_loss_gen_history, overtraining_threshold, 0.01)
1533
-
1534
- if is_overtraining_gen: consecutive_increases_gen += 1
1535
- else: consecutive_increases_gen = 0
1536
-
1537
- if epoch % save_every_epoch == 0: save_to_json(training_file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history)
1538
-
1539
- if (is_overtraining_gen and consecutive_increases_gen == overtraining_threshold or is_overtraining_disc and consecutive_increases_disc == (overtraining_threshold * 2)):
1540
- logger.info(translations["overtraining_find"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
1541
- done = True
1542
- else:
1543
- logger.info(translations["best_epoch"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
1544
-
1545
- old_model_files = glob.glob(os.path.join("assets", "weights", f"{model_name}_*e_*s_best_epoch.pth"))
1546
-
1547
- for file in old_model_files:
1548
- model_del.append(file)
1549
-
1550
- model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth"))
1551
-
1552
- if epoch >= custom_total_epoch:
1553
- lowest_value_rounded = float(lowest_value["value"])
1554
- lowest_value_rounded = round(lowest_value_rounded, 3)
1555
-
1556
- logger.info(translations["success_training"].format(epoch=epoch, global_step=global_step, loss_gen_all=round(loss_gen_all.item(), 3)))
1557
- logger.info(translations["training_info"].format(lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
1558
-
1559
- pid_file_path = os.path.join(experiment_dir, "config.json")
1560
-
1561
- with open(pid_file_path, "r") as pid_file:
1562
- pid_data = json.load(pid_file)
1563
- with open(pid_file_path, "w") as pid_file:
1564
- pid_data.pop("process_pids", None)
1565
- json.dump(pid_data, pid_file, indent=4)
1566
-
1567
- model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
1568
- done = True
1569
-
1570
- if model_add:
1571
- ckpt = (net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict())
1572
-
1573
- for m in model_add:
1574
- if not os.path.exists(m): extract_model(ckpt=ckpt, sr=sample_rate, pitch_guidance=pitch_guidance == True, name=model_name, model_dir=m, epoch=epoch, step=global_step, version=version, hps=hps, model_author=model_author)
1575
-
1576
- for m in model_del:
1577
- os.remove(m)
1578
-
1579
- lowest_value_rounded = float(lowest_value["value"])
1580
- lowest_value_rounded = round(lowest_value_rounded, 3)
1581
-
1582
- if epoch > 1 and overtraining_detector:
1583
- remaining_epochs_gen = overtraining_threshold - consecutive_increases_gen
1584
- remaining_epochs_disc = (overtraining_threshold * 2) - consecutive_increases_disc
1585
-
1586
- logger.info(translations["model_training_info"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step'], remaining_epochs_gen=remaining_epochs_gen, remaining_epochs_disc=remaining_epochs_disc, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
1587
- elif epoch > 1 and overtraining_detector == False: logger.info(translations["model_training_info_2"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
1588
- else: logger.info(translations["model_training_info_3"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record()))
1589
-
1590
- last_loss_gen_all = loss_gen_all
1591
-
1592
- if done: os._exit(2333333)
1593
-
1594
- if __name__ == "__main__":
1595
- torch.multiprocessing.set_start_method("spawn")
1596
-
1597
- try:
1598
- main()
1599
- except Exception as e:
1600
- logger.error(f"{translations['training_error']} {e}")
main/library/algorithm/commons.py DELETED
@@ -1,100 +0,0 @@
1
- import math
2
- import torch
3
-
4
- from typing import List, Optional
5
-
6
- def init_weights(m, mean=0.0, std=0.01):
7
- classname = m.__class__.__name__
8
- if classname.find("Conv") != -1: m.weight.data.normal_(mean, std)
9
-
10
- def get_padding(kernel_size, dilation=1):
11
- return int((kernel_size * dilation - dilation) / 2)
12
-
13
- def convert_pad_shape(pad_shape):
14
- l = pad_shape[::-1]
15
- pad_shape = [item for sublist in l for item in sublist]
16
- return pad_shape
17
-
18
- def kl_divergence(m_p, logs_p, m_q, logs_q):
19
- kl = (logs_q - logs_p) - 0.5
20
- kl += (0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q))
21
- return kl
22
-
23
- def slice_segments(x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4, dim: int = 2):
24
- if dim == 2: ret = torch.zeros_like(x[:, :segment_size])
25
- elif dim == 3: ret = torch.zeros_like(x[:, :, :segment_size])
26
-
27
- for i in range(x.size(0)):
28
- idx_str = ids_str[i].item()
29
- idx_end = idx_str + segment_size
30
-
31
- if dim == 2: ret[i] = x[i, idx_str:idx_end]
32
- else: ret[i] = x[i, :, idx_str:idx_end]
33
-
34
- return ret
35
-
36
- def rand_slice_segments(x, x_lengths=None, segment_size=4):
37
- b, d, t = x.size()
38
-
39
- if x_lengths is None: x_lengths = t
40
-
41
- ids_str_max = x_lengths - segment_size + 1
42
- ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
43
- ret = slice_segments(x, ids_str, segment_size, dim=3)
44
- return ret, ids_str
45
-
46
- def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
47
- position = torch.arange(length, dtype=torch.float)
48
- num_timescales = channels // 2
49
- log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
50
- inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
51
- scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
52
- signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
53
- signal = torch.nn.functional.pad(signal, [0, 0, 0, channels % 2])
54
- signal = signal.view(1, channels, length)
55
- return signal
56
-
57
-
58
- def subsequent_mask(length):
59
- mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
60
- return mask
61
-
62
-
63
- @torch.jit.script
64
- def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
65
- n_channels_int = n_channels[0]
66
- in_act = input_a + input_b
67
- t_act = torch.tanh(in_act[:, :n_channels_int, :])
68
- s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
69
- acts = t_act * s_act
70
- return acts
71
-
72
-
73
- def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]:
74
- return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist()
75
-
76
-
77
- def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None):
78
- if max_length is None: max_length = length.max()
79
-
80
- x = torch.arange(max_length, dtype=length.dtype, device=length.device)
81
- return x.unsqueeze(0) < length.unsqueeze(1)
82
-
83
-
84
- def clip_grad_value(parameters, clip_value, norm_type=2):
85
- if isinstance(parameters, torch.Tensor): parameters = [parameters]
86
-
87
- parameters = list(filter(lambda p: p.grad is not None, parameters))
88
- norm_type = float(norm_type)
89
-
90
- if clip_value is not None: clip_value = float(clip_value)
91
-
92
- total_norm = 0
93
-
94
- for p in parameters:
95
- param_norm = p.grad.data.norm(norm_type)
96
- total_norm += param_norm.item() ** norm_type
97
- if clip_value is not None: p.grad.data.clamp_(min=-clip_value, max=clip_value)
98
-
99
- total_norm = total_norm ** (1.0 / norm_type)
100
- return total_norm
 
main/library/algorithm/modules.py DELETED
@@ -1,80 +0,0 @@
1
- import os
2
- import sys
3
- import torch
4
-
5
- now_dir = os.getcwd()
6
- sys.path.append(now_dir)
7
-
8
- from .commons import fused_add_tanh_sigmoid_multiply
9
-
10
- class WaveNet(torch.nn.Module):
11
- def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
12
- super(WaveNet, self).__init__()
13
- assert kernel_size % 2 == 1
14
-
15
- self.hidden_channels = hidden_channels
16
- self.kernel_size = (kernel_size,)
17
- self.dilation_rate = dilation_rate
18
-
19
- self.n_layers = n_layers
20
- self.gin_channels = gin_channels
21
- self.p_dropout = p_dropout
22
-
23
- self.in_layers = torch.nn.ModuleList()
24
- self.res_skip_layers = torch.nn.ModuleList()
25
- self.drop = torch.nn.Dropout(p_dropout)
26
-
27
- if gin_channels != 0:
28
- cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
29
- self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")
30
-
31
- dilations = [dilation_rate**i for i in range(n_layers)]
32
- paddings = [(kernel_size * d - d) // 2 for d in dilations]
33
-
34
- for i in range(n_layers):
35
- in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilations[i], padding=paddings[i])
36
- in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")
37
-
38
- self.in_layers.append(in_layer)
39
-
40
- res_skip_channels = (hidden_channels if i == n_layers - 1 else 2 * hidden_channels)
41
- res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
42
- res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")
43
-
44
- self.res_skip_layers.append(res_skip_layer)
45
-
46
- def forward(self, x, x_mask, g=None, **kwargs):
47
- output = torch.zeros_like(x)
48
- n_channels_tensor = torch.IntTensor([self.hidden_channels])
49
-
50
- if g is not None: g = self.cond_layer(g)
51
-
52
- for i in range(self.n_layers):
53
- x_in = self.in_layers[i](x)
54
-
55
- if g is not None:
56
- cond_offset = i * 2 * self.hidden_channels
57
- g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
58
- else: g_l = torch.zeros_like(x_in)
59
-
60
- acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
61
- acts = self.drop(acts)
62
-
63
- res_skip_acts = self.res_skip_layers[i](acts)
64
-
65
- if i < self.n_layers - 1:
66
- res_acts = res_skip_acts[:, : self.hidden_channels, :]
67
- x = (x + res_acts) * x_mask
68
- output = output + res_skip_acts[:, self.hidden_channels :, :]
69
- else: output = output + res_skip_acts
70
-
71
- return output * x_mask
72
-
73
- def remove_weight_norm(self):
74
- if self.gin_channels != 0: torch.nn.utils.remove_weight_norm(self.cond_layer)
75
-
76
- for l in self.in_layers:
77
- torch.nn.utils.remove_weight_norm(l)
78
-
79
- for l in self.res_skip_layers:
80
- torch.nn.utils.remove_weight_norm(l)
 
main/library/algorithm/residuals.py DELETED
@@ -1,170 +0,0 @@
1
- import os
2
- import sys
3
- import torch
4
-
5
- from typing import Optional
6
-
7
- from torch.nn.utils import remove_weight_norm
8
- from torch.nn.utils.parametrizations import weight_norm
9
-
10
- now_dir = os.getcwd()
11
- sys.path.append(now_dir)
12
-
13
- from .modules import WaveNet
14
- from .commons import get_padding, init_weights
15
-
16
-
17
- LRELU_SLOPE = 0.1
18
-
19
- def create_conv1d_layer(channels, kernel_size, dilation):
20
- return weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation, padding=get_padding(kernel_size, dilation)))
21
-
22
- def apply_mask(tensor, mask):
23
- return tensor * mask if mask is not None else tensor
24
-
25
- class ResBlockBase(torch.nn.Module):
26
- def __init__(self, channels, kernel_size, dilations):
27
- super(ResBlockBase, self).__init__()
28
-
29
- self.convs1 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, d) for d in dilations])
30
- self.convs1.apply(init_weights)
31
-
32
- self.convs2 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, 1) for _ in dilations])
33
- self.convs2.apply(init_weights)
34
-
35
- def forward(self, x, x_mask=None):
36
- for c1, c2 in zip(self.convs1, self.convs2):
37
- xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
38
- xt = apply_mask(xt, x_mask)
39
- xt = torch.nn.functional.leaky_relu(c1(xt), LRELU_SLOPE)
40
- xt = apply_mask(xt, x_mask)
41
- xt = c2(xt)
42
- x = xt + x
43
- return apply_mask(x, x_mask)
44
-
45
- def remove_weight_norm(self):
46
- for conv in self.convs1 + self.convs2:
47
- remove_weight_norm(conv)
48
-
49
- class ResBlock1(ResBlockBase):
50
- def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
51
- super(ResBlock1, self).__init__(channels, kernel_size, dilation)
52
-
53
- class ResBlock2(ResBlockBase):
54
- def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
55
- super(ResBlock2, self).__init__(channels, kernel_size, dilation)
56
-
57
- class Log(torch.nn.Module):
58
- def forward(self, x, x_mask, reverse=False, **kwargs):
59
- if not reverse:
60
- y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
61
- logdet = torch.sum(-y, [1, 2])
62
- return y, logdet
63
- else:
64
- x = torch.exp(x) * x_mask
65
- return x
66
-
67
- class Flip(torch.nn.Module):
68
- def forward(self, x, *args, reverse=False, **kwargs):
69
- x = torch.flip(x, [1])
70
- if not reverse:
71
- logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
72
- return x, logdet
73
- else: return x
74
-
75
- class ElementwiseAffine(torch.nn.Module):
76
- def __init__(self, channels):
77
- super().__init__()
78
- self.channels = channels
79
- self.m = torch.nn.Parameter(torch.zeros(channels, 1))
80
- self.logs = torch.nn.Parameter(torch.zeros(channels, 1))
81
-
82
- def forward(self, x, x_mask, reverse=False, **kwargs):
83
- if not reverse:
84
- y = self.m + torch.exp(self.logs) * x
85
- y = y * x_mask
86
- logdet = torch.sum(self.logs * x_mask, [1, 2])
87
- return y, logdet
88
- else:
89
- x = (x - self.m) * torch.exp(-self.logs) * x_mask
90
- return x
91
-
92
-
93
- class ResidualCouplingBlock(torch.nn.Module):
94
- def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
95
- super(ResidualCouplingBlock, self).__init__()
96
- self.channels = channels
97
- self.hidden_channels = hidden_channels
98
- self.kernel_size = kernel_size
99
- self.dilation_rate = dilation_rate
100
- self.n_layers = n_layers
101
- self.n_flows = n_flows
102
- self.gin_channels = gin_channels
103
-
104
- self.flows = torch.nn.ModuleList()
105
- for i in range(n_flows):
106
- self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
107
- self.flows.append(Flip())
108
-
109
- def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse = False):
110
- if not reverse:
111
- for flow in self.flows:
112
- x, _ = flow(x, x_mask, g=g, reverse=reverse)
113
- else:
114
- for flow in reversed(self.flows):
115
- x = flow.forward(x, x_mask, g=g, reverse=reverse)
116
-
117
- return x
118
-
119
- def remove_weight_norm(self):
120
- for i in range(self.n_flows):
121
- self.flows[i * 2].remove_weight_norm()
122
-
123
- def __prepare_scriptable__(self):
124
- for i in range(self.n_flows):
125
- for hook in self.flows[i * 2]._forward_pre_hooks.values():
126
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.flows[i * 2])
127
- return self
128
-
129
-
130
- class ResidualCouplingLayer(torch.nn.Module):
131
- def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False):
132
- assert channels % 2 == 0, "Channels/2"
133
- super().__init__()
134
- self.channels = channels
135
- self.hidden_channels = hidden_channels
136
- self.kernel_size = kernel_size
137
- self.dilation_rate = dilation_rate
138
- self.n_layers = n_layers
139
- self.half_channels = channels // 2
140
- self.mean_only = mean_only
141
-
142
- self.pre = torch.nn.Conv1d(self.half_channels, hidden_channels, 1)
143
- self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
144
- self.post = torch.nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
145
- self.post.weight.data.zero_()
146
- self.post.bias.data.zero_()
147
-
148
- def forward(self, x, x_mask, g=None, reverse=False):
149
- x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
150
- h = self.pre(x0) * x_mask
151
- h = self.enc(h, x_mask, g=g)
152
- stats = self.post(h) * x_mask
153
-
154
- if not self.mean_only: m, logs = torch.split(stats, [self.half_channels] * 2, 1)
155
- else:
156
- m = stats
157
- logs = torch.zeros_like(m)
158
-
159
- if not reverse:
160
- x1 = m + x1 * torch.exp(logs) * x_mask
161
- x = torch.cat([x0, x1], 1)
162
- logdet = torch.sum(logs, [1, 2])
163
- return x, logdet
164
- else:
165
- x1 = (x1 - m) * torch.exp(-logs) * x_mask
166
- x = torch.cat([x0, x1], 1)
167
- return x
168
-
169
- def remove_weight_norm(self):
170
- self.enc.remove_weight_norm()
 
main/library/algorithm/separator.py DELETED
@@ -1,420 +0,0 @@
1
- import os
2
- import sys
3
- import time
4
- import json
5
- import yaml
6
- import torch
7
- import codecs
8
- import hashlib
9
- import logging
10
- import platform
11
- import warnings
12
- import requests
13
- import subprocess
14
-
15
- import onnxruntime as ort
16
-
17
- from tqdm import tqdm
18
- from importlib import metadata, import_module
19
-
20
- now_dir = os.getcwd()
21
- sys.path.append(now_dir)
22
-
23
- from main.configs.config import Config
24
- translations = Config().translations
25
-
26
- class Separator:
27
- def __init__(self, log_level=logging.INFO, log_formatter=None, model_file_dir="assets/model/uvr5", output_dir=None, output_format="wav", output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}):
28
- self.logger = logging.getLogger(__name__)
29
- self.logger.setLevel(log_level)
30
- self.log_level = log_level
31
- self.log_formatter = log_formatter
32
- self.log_handler = logging.StreamHandler()
33
-
34
- if self.log_formatter is None: self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")
35
-
36
- self.log_handler.setFormatter(self.log_formatter)
37
-
38
- if not self.logger.hasHandlers(): self.logger.addHandler(self.log_handler)
39
- if log_level > logging.DEBUG: warnings.filterwarnings("ignore")
40
-
41
- self.logger.info(translations["separator_info"].format(output_dir=output_dir, output_format=output_format))
42
-
43
- self.model_file_dir = model_file_dir
44
-
45
- if output_dir is None:
46
- output_dir = os.getcwd()
47
- self.logger.info(translations["output_dir_is_none"])
48
-
49
- self.output_dir = output_dir
50
-
51
- os.makedirs(self.model_file_dir, exist_ok=True)
52
- os.makedirs(self.output_dir, exist_ok=True)
53
-
54
- self.output_format = output_format
55
- self.output_bitrate = output_bitrate
56
-
57
- if self.output_format is None: self.output_format = "wav"
58
-
59
- self.normalization_threshold = normalization_threshold
60
-
61
- if normalization_threshold <= 0 or normalization_threshold > 1: raise ValueError(translations[">0or=1"])
62
-
63
- self.output_single_stem = output_single_stem
64
-
65
- if output_single_stem is not None: self.logger.debug(translations["output_single"].format(output_single_stem=output_single_stem))
66
-
67
- self.invert_using_spec = invert_using_spec
68
- if self.invert_using_spec: self.logger.debug(translations["step2"])
69
-
70
- try:
71
- self.sample_rate = int(sample_rate)
72
-
73
- if self.sample_rate <= 0: raise ValueError(translations["other_than_zero"].format(sample_rate=self.sample_rate))
74
- if self.sample_rate > 12800000: raise ValueError(translations["too_high"].format(sample_rate=self.sample_rate))
75
- except ValueError:
76
- raise ValueError(translations["sr_not_valid"])
77
-
78
- self.arch_specific_params = {"MDX": mdx_params, "Demucs": demucs_params}
79
- self.torch_device = None
80
- self.torch_device_cpu = None
81
- self.torch_device_mps = None
82
- self.onnx_execution_provider = None
83
- self.model_instance = None
84
- self.model_is_uvr_vip = False
85
- self.model_friendly_name = None
86
-
87
- self.setup_accelerated_inferencing_device()
88
-
89
- def setup_accelerated_inferencing_device(self):
90
- system_info = self.get_system_info()
91
- self.check_ffmpeg_installed()
92
- self.log_onnxruntime_packages()
93
- self.setup_torch_device(system_info)
94
-
95
- def get_system_info(self):
96
- os_name = platform.system()
97
- os_version = platform.version()
98
-
99
- self.logger.info(f"{translations['os']}: {os_name} {os_version}")
100
-
101
- system_info = platform.uname()
102
- self.logger.info(translations["platform_info"].format(system_info=system_info, node=system_info.node, release=system_info.release, machine=system_info.machine, processor=system_info.processor))
103
-
104
- python_version = platform.python_version()
105
- self.logger.info(f"{translations['name_ver'].format(name='python')}: {python_version}")
106
-
107
- pytorch_version = torch.__version__
108
- self.logger.info(f"{translations['name_ver'].format(name='pytorch')}: {pytorch_version}")
109
-
110
- return system_info
111
-
112
- def check_ffmpeg_installed(self):
113
- try:
114
- ffmpeg_version_output = subprocess.check_output(["ffmpeg", "-version"], text=True)
115
- first_line = ffmpeg_version_output.splitlines()[0]
116
- self.logger.info(f"{translations['install_ffmpeg']}: {first_line}")
117
- except FileNotFoundError:
118
- self.logger.error(translations["none_ffmpeg"])
119
- if "PYTEST_CURRENT_TEST" not in os.environ: raise
120
-
121
- def log_onnxruntime_packages(self):
122
- onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
123
- onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")
124
-
125
- if onnxruntime_gpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='GPU')}: {onnxruntime_gpu_package.version}")
126
- if onnxruntime_cpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='CPU')}: {onnxruntime_cpu_package.version}")
127
-
128
- def setup_torch_device(self, system_info):
129
- hardware_acceleration_enabled = False
130
- ort_providers = ort.get_available_providers()
131
-
132
- self.torch_device_cpu = torch.device("cpu")
133
-
134
- if torch.cuda.is_available():
135
- self.configure_cuda(ort_providers)
136
- hardware_acceleration_enabled = True
137
- elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
138
- self.configure_mps(ort_providers)
139
- hardware_acceleration_enabled = True
140
-
141
- if not hardware_acceleration_enabled:
142
- self.logger.info(translations["running_in_cpu"])
143
- self.torch_device = self.torch_device_cpu
144
- self.onnx_execution_provider = ["CPUExecutionProvider"]
145
-
146
- def configure_cuda(self, ort_providers):
147
- self.logger.info(translations["running_in_cuda"])
148
- self.torch_device = torch.device("cuda")
149
-
150
- if "CUDAExecutionProvider" in ort_providers:
151
- self.logger.info(translations["onnx_have"].format(have='CUDAExecutionProvider'))
152
- self.onnx_execution_provider = ["CUDAExecutionProvider"]
153
- else: self.logger.warning(translations["onnx_not_have"].format(have='CUDAExecutionProvider'))
154
-
155
- def configure_mps(self, ort_providers):
156
-         self.logger.info("Setting Torch device to MPS")
157
- self.torch_device_mps = torch.device("mps")
158
- self.torch_device = self.torch_device_mps
159
-
160
- if "CoreMLExecutionProvider" in ort_providers:
161
- self.logger.info(translations["onnx_have"].format(have='CoreMLExecutionProvider'))
162
- self.onnx_execution_provider = ["CoreMLExecutionProvider"]
163
- else: self.logger.warning(translations["onnx_not_have"].format(have='CoreMLExecutionProvider'))
164
-
165
- def get_package_distribution(self, package_name):
166
- try:
167
- return metadata.distribution(package_name)
168
- except metadata.PackageNotFoundError:
169
- self.logger.debug(translations["python_not_install"].format(package_name=package_name))
170
- return None
171
-
172
- def get_model_hash(self, model_path):
173
- self.logger.debug(translations["hash"].format(model_path=model_path))
174
-
175
- try:
176
- with open(model_path, "rb") as f:
177
- f.seek(-10000 * 1024, 2)
178
- return hashlib.md5(f.read()).hexdigest()
179
- except IOError as e:
180
- self.logger.error(translations["ioerror"].format(e=e))
181
-
182
- return hashlib.md5(open(model_path, "rb").read()).hexdigest()
183
-
184
- def download_file_if_not_exists(self, url, output_path):
185
- if os.path.isfile(output_path):
186
- self.logger.debug(translations["cancel_download"].format(output_path=output_path))
187
- return
188
-
189
- self.logger.debug(translations["download_model"].format(url=url, output_path=output_path))
190
- response = requests.get(url, stream=True, timeout=300)
191
-
192
- if response.status_code == 200:
193
- total_size_in_bytes = int(response.headers.get("content-length", 0))
194
- progress_bar = tqdm(total=total_size_in_bytes)
195
-
196
- with open(output_path, "wb") as f:
197
- for chunk in response.iter_content(chunk_size=8192):
198
- progress_bar.update(len(chunk))
199
- f.write(chunk)
200
-
201
- progress_bar.close()
202
- else: raise RuntimeError(translations["download_error"].format(url=url, status_code=response.status_code))
203
-
204
- def print_uvr_vip_message(self):
205
- if self.model_is_uvr_vip:
206
- self.logger.warning(translations["vip_model"].format(model_friendly_name=self.model_friendly_name))
207
- self.logger.warning(translations["vip_print"])
208
-
209
- def list_supported_model_files(self):
210
- download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")
211
-
212
- model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
213
- self.logger.debug(translations["load_download_json"])
214
-
215
- filtered_demucs_v4 = {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}
216
-
217
- model_files_grouped_by_type = {"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]}, "Demucs": filtered_demucs_v4}
218
- return model_files_grouped_by_type
219
-
220
- def download_model_files(self, model_filename):
221
- model_path = os.path.join(self.model_file_dir, f"{model_filename}")
222
-
223
- supported_model_files_grouped = self.list_supported_model_files()
224
- public_model_repo_url_prefix = codecs.decode("uggcf://tvguho.pbz/GEiyie/zbqry_ercb/eryrnfrf/qbjaybnq/nyy_choyvp_hie_zbqryf", "rot13")
225
- vip_model_repo_url_prefix = codecs.decode("uggcf://tvguho.pbz/Nawbx0109/nv_zntvp/eryrnfrf/qbjaybnq/i5", "rot13")
226
-
227
- audio_separator_models_repo_url_prefix = codecs.decode("uggcf://tvguho.pbz/abznqxnenbxr/clguba-nhqvb-frcnengbe/eryrnfrf/qbjaybnq/zbqry-pbasvtf", "rot13")
228
-
229
- yaml_config_filename = None
230
-
231
- self.logger.debug(translations["search_model"].format(model_filename=model_filename))
232
-
233
- for model_type, model_list in supported_model_files_grouped.items():
234
- for model_friendly_name, model_download_list in model_list.items():
235
- self.model_is_uvr_vip = "VIP" in model_friendly_name
236
- model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix
237
-
238
- if isinstance(model_download_list, str) and model_download_list == model_filename:
239
- self.logger.debug(translations["single_model"].format(model_friendly_name=model_friendly_name))
240
- self.model_friendly_name = model_friendly_name
241
-
242
- try:
243
- self.download_file_if_not_exists(f"{model_repo_url_prefix}/{model_filename}", model_path)
244
- except RuntimeError:
245
- self.logger.debug(translations["not_found_model"])
246
- self.download_file_if_not_exists(f"{audio_separator_models_repo_url_prefix}/{model_filename}", model_path)
247
-
248
- self.print_uvr_vip_message()
249
-
250
- self.logger.debug(translations["single_model_path"].format(model_path=model_path))
251
-
252
- return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
253
- elif isinstance(model_download_list, dict):
254
- this_model_matches_input_filename = False
255
-
256
- for file_name, file_url in model_download_list.items():
257
- if file_name == model_filename or file_url == model_filename:
258
- self.logger.debug(translations["find_model"].format(model_filename=model_filename, model_friendly_name=model_friendly_name))
259
- this_model_matches_input_filename = True
260
-
261
- if this_model_matches_input_filename:
262
- self.logger.debug(translations["find_models"].format(model_friendly_name=model_friendly_name))
263
- self.model_friendly_name = model_friendly_name
264
- self.print_uvr_vip_message()
265
-
266
- for config_key, config_value in model_download_list.items():
267
- self.logger.debug(f"{translations['find_path']}: {config_key} -> {config_value}")
268
-
269
- if config_value.startswith("http"): self.download_file_if_not_exists(config_value, os.path.join(self.model_file_dir, config_key))
270
- elif config_key.endswith(".ckpt"):
271
- try:
272
- download_url = f"{model_repo_url_prefix}/{config_key}"
273
- self.download_file_if_not_exists(download_url, os.path.join(self.model_file_dir, config_key))
274
- except RuntimeError:
275
- self.logger.debug(translations["not_found_model_warehouse"])
276
- download_url = f"{audio_separator_models_repo_url_prefix}/{config_key}"
277
- self.download_file_if_not_exists(download_url, os.path.join(self.model_file_dir, config_key))
278
-
279
- if model_filename.endswith(".yaml"):
280
- self.logger.warning(translations["yaml_warning"].format(model_filename=model_filename))
281
- self.logger.warning(translations["yaml_warning_2"].format(config_key=config_key))
282
- self.logger.warning(translations["yaml_warning_3"])
283
-
284
- model_filename = config_key
285
- model_path = os.path.join(self.model_file_dir, f"{model_filename}")
286
-
287
- yaml_config_filename = config_value
288
- yaml_config_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
289
-
290
- try:
291
- url = codecs.decode("uggcf://enj.tvguhohfrepbagrag.pbz/GEiyie/nccyvpngvba_qngn/znva/zqk_zbqry_qngn/zqk_p_pbasvtf", "rot13")
292
- yaml_config_url = f"{url}/{yaml_config_filename}"
293
- self.download_file_if_not_exists(f"{yaml_config_url}", yaml_config_filepath)
294
- except RuntimeError:
295
- self.logger.debug(translations["yaml_debug"])
296
- yaml_config_url = f"{audio_separator_models_repo_url_prefix}/{yaml_config_filename}"
297
- self.download_file_if_not_exists(f"{yaml_config_url}", yaml_config_filepath)
298
- else:
299
- download_url = f"{model_repo_url_prefix}/{config_value}"
300
- self.download_file_if_not_exists(download_url, os.path.join(self.model_file_dir, config_value))
301
-
302
- self.logger.debug(translations["download_model_friendly"].format(model_friendly_name=model_friendly_name, model_path=model_path))
303
-
304
- return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
305
-
306
- raise ValueError(translations["not_found_model_2"].format(model_filename=model_filename))
307
-
308
- def load_model_data_from_yaml(self, yaml_config_filename):
309
- model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename) if not os.path.exists(yaml_config_filename) else yaml_config_filename
310
-
311
- self.logger.debug(translations["load_yaml"].format(model_data_yaml_filepath=model_data_yaml_filepath))
312
-
313
- model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
314
- self.logger.debug(translations["load_yaml_2"].format(model_data=model_data))
315
-
316
- if "roformer" in model_data_yaml_filepath: model_data["is_roformer"] = True
317
-
318
- return model_data
319
-
320
- def load_model_data_using_hash(self, model_path):
321
- mdx_model_data_url = codecs.decode("uggcf://enj.tvguhohfrepbagrag.pbz/GEiyie/nccyvpngvba_qngn/znva/zqk_zbqry_qngn/zbqry_qngn_arj.wfba", "rot13")
322
-
323
- self.logger.debug(translations["hash_md5"])
324
- model_hash = self.get_model_hash(model_path)
325
- self.logger.debug(translations["model_hash"].format(model_path=model_path, model_hash=model_hash))
326
-
327
- mdx_model_data_path = os.path.join(self.model_file_dir, "mdx_model_data.json")
328
- self.logger.debug(translations["mdx_data"].format(mdx_model_data_path=mdx_model_data_path))
329
- self.download_file_if_not_exists(mdx_model_data_url, mdx_model_data_path)
330
-
331
- self.logger.debug(translations["load_mdx"])
332
- mdx_model_data_object = json.load(open(mdx_model_data_path, encoding="utf-8"))
333
-
334
- if model_hash in mdx_model_data_object: model_data = mdx_model_data_object[model_hash]
335
- else: raise ValueError(translations["model_not_support"].format(model_hash=model_hash))
336
-
337
- self.logger.debug(translations["uvr_json"].format(model_hash=model_hash, model_data=model_data))
338
-
339
- return model_data
340
-
341
- def load_model(self, model_filename):
342
- self.logger.info(translations["loading_model"].format(model_filename=model_filename))
343
-
344
- load_model_start_time = time.perf_counter()
345
-
346
- model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
347
- model_name = model_filename.split(".")[0]
348
- self.logger.debug(translations["download_model_friendly_2"].format(model_friendly_name=model_friendly_name, model_path=model_path))
349
-
350
- if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
351
-
352
- model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path)
353
-
354
- common_params = {
355
- "logger": self.logger,
356
- "log_level": self.log_level,
357
- "torch_device": self.torch_device,
358
- "torch_device_cpu": self.torch_device_cpu,
359
- "torch_device_mps": self.torch_device_mps,
360
- "onnx_execution_provider": self.onnx_execution_provider,
361
- "model_name": model_name,
362
- "model_path": model_path,
363
- "model_data": model_data,
364
- "output_format": self.output_format,
365
- "output_bitrate": self.output_bitrate,
366
- "output_dir": self.output_dir,
367
- "normalization_threshold": self.normalization_threshold,
368
- "output_single_stem": self.output_single_stem,
369
- "invert_using_spec": self.invert_using_spec,
370
- "sample_rate": self.sample_rate,
371
- }
372
-
373
- separator_classes = {"MDX": "mdx_separator.MDXSeparator", "Demucs": "demucs_separator.DemucsSeparator"}
374
-
375
- if model_type not in self.arch_specific_params or model_type not in separator_classes: raise ValueError(translations["model_type_not_support"].format(model_type=model_type))
376
- if model_type == "Demucs" and sys.version_info < (3, 10): raise Exception(translations["demucs_not_support_python<3.10"])
377
-
378
- self.logger.debug(f"{translations['import_module']} {model_type}: {separator_classes[model_type]}")
379
-
380
- module_name, class_name = separator_classes[model_type].split(".")
381
- module = import_module(f"main.library.architectures.{module_name}")
382
- separator_class = getattr(module, class_name)
383
-
384
- self.logger.debug(f"{translations['initialization']} {model_type}: {separator_class}")
385
- self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
386
-
387
- self.logger.debug(translations["loading_model_success"])
388
- self.logger.info(f"{translations['loading_model_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - load_model_start_time)))}")
389
-
390
- def separate(self, audio_file_path):
391
- self.logger.info(f"{translations['starting_separator']}: {audio_file_path}")
392
- separate_start_time = time.perf_counter()
393
-
394
- self.logger.debug(translations["normalization"].format(normalization_threshold=self.normalization_threshold))
395
-
396
- output_files = self.model_instance.separate(audio_file_path)
397
-
398
- self.model_instance.clear_gpu_cache()
399
-
400
- self.model_instance.clear_file_specific_paths()
401
-
402
- self.print_uvr_vip_message()
403
-
404
- self.logger.debug(translations["separator_success_3"])
405
- self.logger.info(f"{translations['separator_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - separate_start_time)))}")
406
-
407
- return output_files
408
-
409
- def download_model_and_data(self, model_filename):
410
- self.logger.info(translations["loading_separator_model"].format(model_filename=model_filename))
411
-
412
- model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
413
-
414
- if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
415
-
416
- model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path)
417
-
418
- model_data_dict_size = len(model_data)
419
-
420
- self.logger.info(translations["downloading_model"].format(model_type=model_type, model_friendly_name=model_friendly_name, model_path=model_path, model_data_dict_size=model_data_dict_size))
 
main/library/algorithm/synthesizers.py DELETED
@@ -1,590 +0,0 @@
1
- import os
2
- import sys
3
- import math
4
- import torch
5
-
6
- from typing import Optional
7
- from torch.nn.utils import remove_weight_norm
8
- from torch.nn.utils.parametrizations import weight_norm
9
-
10
- now_dir = os.getcwd()
11
- sys.path.append(now_dir)
12
-
13
- from .modules import WaveNet
14
- from .residuals import ResidualCouplingBlock, ResBlock1, ResBlock2, LRELU_SLOPE
15
- from .commons import init_weights, slice_segments, rand_slice_segments, sequence_mask, convert_pad_shape
16
-
17
- class Generator(torch.nn.Module):
18
- def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
19
- super(Generator, self).__init__()
20
-
21
- self.num_kernels = len(resblock_kernel_sizes)
22
- self.num_upsamples = len(upsample_rates)
23
- self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
24
- resblock = ResBlock1 if resblock == "1" else ResBlock2
25
-
26
- self.ups_and_resblocks = torch.nn.ModuleList()
27
-
28
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
29
- self.ups_and_resblocks.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
30
-
31
- ch = upsample_initial_channel // (2 ** (i + 1))
32
-
33
- for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
34
- self.ups_and_resblocks.append(resblock(ch, k, d))
35
-
36
- self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
37
- self.ups_and_resblocks.apply(init_weights)
38
-
39
- if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
40
-
41
- def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
42
- x = self.conv_pre(x)
43
- if g is not None: x = x + self.cond(g)
44
-
45
- resblock_idx = 0
46
-
47
- for _ in range(self.num_upsamples):
48
- x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
49
- x = self.ups_and_resblocks[resblock_idx](x)
50
- resblock_idx += 1
51
- xs = 0
52
-
53
- for _ in range(self.num_kernels):
54
- xs += self.ups_and_resblocks[resblock_idx](x)
55
- resblock_idx += 1
56
-
57
- x = xs / self.num_kernels
58
-
59
- x = torch.nn.functional.leaky_relu(x)
60
- x = self.conv_post(x)
61
- x = torch.tanh(x)
62
-
63
- return x
64
-
65
- def __prepare_scriptable__(self):
66
- for l in self.ups_and_resblocks:
67
- for hook in l._forward_pre_hooks.values():
68
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(l)
69
-
70
- return self
71
- def remove_weight_norm(self):
72
- for l in self.ups_and_resblocks:
73
- remove_weight_norm(l)
74
-
75
- class SineGen(torch.nn.Module):
76
- def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
77
- super(SineGen, self).__init__()
78
- self.sine_amp = sine_amp
79
- self.noise_std = noise_std
80
- self.harmonic_num = harmonic_num
81
- self.dim = self.harmonic_num + 1
82
- self.sample_rate = samp_rate
83
- self.voiced_threshold = voiced_threshold
84
-
85
- def _f02uv(self, f0):
86
- uv = torch.ones_like(f0)
87
- uv = uv * (f0 > self.voiced_threshold)
88
- return uv
89
-
90
- def forward(self, f0: torch.Tensor, upp: int):
91
- with torch.no_grad():
92
- f0 = f0[:, None].transpose(1, 2)
93
- f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
94
- f0_buf[:, :, 0] = f0[:, :, 0]
95
- f0_buf[:, :, 1:] = (f0_buf[:, :, 0:1] * torch.arange(2, self.harmonic_num + 2, device=f0.device)[None, None, :])
96
- rad_values = (f0_buf / float(self.sample_rate)) % 1
97
- rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device)
98
- rand_ini[:, 0] = 0
99
- rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
100
- tmp_over_one = torch.cumsum(rad_values, 1)
101
- tmp_over_one *= upp
102
- tmp_over_one = torch.nn.functional.interpolate(tmp_over_one.transpose(2, 1), scale_factor=float(upp), mode="linear", align_corners=True).transpose(2, 1)
103
- rad_values = torch.nn.functional.interpolate(rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
104
- tmp_over_one %= 1
105
- tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
106
- cumsum_shift = torch.zeros_like(rad_values)
107
- cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
108
- sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi)
109
- sine_waves = sine_waves * self.sine_amp
110
- uv = self._f02uv(f0)
111
- uv = torch.nn.functional.interpolate(uv.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
112
- noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
113
- noise = noise_amp * torch.randn_like(sine_waves)
114
- sine_waves = sine_waves * uv + noise
115
-
116
- return sine_waves, uv, noise
117
-
118
- class SourceModuleHnNSF(torch.nn.Module):
119
- def __init__(self, sample_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0, is_half=True):
120
- super(SourceModuleHnNSF, self).__init__()
121
-
122
- self.sine_amp = sine_amp
123
- self.noise_std = add_noise_std
124
- self.is_half = is_half
125
-
126
- self.l_sin_gen = SineGen(sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
127
- self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
128
- self.l_tanh = torch.nn.Tanh()
129
-
130
- def forward(self, x: torch.Tensor, upsample_factor: int = 1):
131
- sine_wavs, uv, _ = self.l_sin_gen(x, upsample_factor)
132
- sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
133
- sine_merge = self.l_tanh(self.l_linear(sine_wavs))
134
- return sine_merge, None, None
135
-
136
-
137
- class GeneratorNSF(torch.nn.Module):
138
- def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, is_half=False):
139
- super(GeneratorNSF, self).__init__()
140
-
141
- self.num_kernels = len(resblock_kernel_sizes)
142
- self.num_upsamples = len(upsample_rates)
143
- self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
144
- self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0, is_half=is_half)
145
-
146
- self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
147
- resblock_cls = ResBlock1 if resblock == "1" else ResBlock2
148
-
149
- self.ups = torch.nn.ModuleList()
150
- self.noise_convs = torch.nn.ModuleList()
151
-
152
- channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(len(upsample_rates))]
153
- stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]
154
-
155
- for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
156
- self.ups.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), channels[i], k, u, padding=(k - u) // 2)))
157
- self.noise_convs.append(torch.nn.Conv1d(1, channels[i], kernel_size=(stride_f0s[i] * 2 if stride_f0s[i] > 1 else 1), stride=stride_f0s[i], padding=(stride_f0s[i] // 2 if stride_f0s[i] > 1 else 0)))
158
-
159
- self.resblocks = torch.nn.ModuleList([resblock_cls(channels[i], k, d) for i in range(len(self.ups)) for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)])
160
-
161
- self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)
162
- self.ups.apply(init_weights)
163
-
164
- if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
165
-
166
- self.upp = math.prod(upsample_rates)
167
- self.lrelu_slope = LRELU_SLOPE
168
-
169
- def forward(self, x, f0, g: Optional[torch.Tensor] = None):
170
- har_source, _, _ = self.m_source(f0, self.upp)
171
- har_source = har_source.transpose(1, 2)
172
- x = self.conv_pre(x)
173
-
174
- if g is not None: x = x + self.cond(g)
175
-
176
- for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
177
- x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
178
- x = ups(x)
179
- x = x + noise_convs(har_source)
180
-
181
- xs = sum([resblock(x) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
182
- x = xs / self.num_kernels
183
-
184
- x = torch.nn.functional.leaky_relu(x)
185
- x = torch.tanh(self.conv_post(x))
186
- return x
187
-
188
- def remove_weight_norm(self):
189
- for l in self.ups:
190
- remove_weight_norm(l)
191
- for l in self.resblocks:
192
- l.remove_weight_norm()
193
-
194
- def __prepare_scriptable__(self):
195
- for l in self.ups:
196
- for hook in l._forward_pre_hooks.values():
197
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): remove_weight_norm(l)
198
-
199
- for l in self.resblocks:
200
- for hook in l._forward_pre_hooks.values():
201
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): remove_weight_norm(l)
202
- return self
203
-
204
- class LayerNorm(torch.nn.Module):
205
- def __init__(self, channels, eps=1e-5):
206
- super().__init__()
207
- self.eps = eps
208
- self.gamma = torch.nn.Parameter(torch.ones(channels))
209
- self.beta = torch.nn.Parameter(torch.zeros(channels))
210
-
211
- def forward(self, x):
212
- x = x.transpose(1, -1)
213
- x = torch.nn.functional.layer_norm(x, (x.size(-1),), self.gamma, self.beta, self.eps)
214
- return x.transpose(1, -1)
215
-
216
- class MultiHeadAttention(torch.nn.Module):
217
- def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
218
- super().__init__()
219
- assert channels % n_heads == 0
220
-
221
- self.channels = channels
222
- self.out_channels = out_channels
223
- self.n_heads = n_heads
224
- self.p_dropout = p_dropout
225
- self.window_size = window_size
226
- self.heads_share = heads_share
227
- self.block_length = block_length
228
- self.proximal_bias = proximal_bias
229
- self.proximal_init = proximal_init
230
- self.attn = None
231
- self.k_channels = channels // n_heads
232
- self.conv_q = torch.nn.Conv1d(channels, channels, 1)
233
- self.conv_k = torch.nn.Conv1d(channels, channels, 1)
234
- self.conv_v = torch.nn.Conv1d(channels, channels, 1)
235
- self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
236
- self.drop = torch.nn.Dropout(p_dropout)
237
-
238
- if window_size is not None:
239
- n_heads_rel = 1 if heads_share else n_heads
240
- rel_stddev = self.k_channels**-0.5
241
-
242
- self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
243
- self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
244
-
245
- torch.nn.init.xavier_uniform_(self.conv_q.weight)
246
- torch.nn.init.xavier_uniform_(self.conv_k.weight)
247
- torch.nn.init.xavier_uniform_(self.conv_v.weight)
248
-
249
- if proximal_init:
250
- with torch.no_grad():
251
- self.conv_k.weight.copy_(self.conv_q.weight)
252
- self.conv_k.bias.copy_(self.conv_q.bias)
253
-
254
- def forward(self, x, c, attn_mask=None):
255
- q = self.conv_q(x)
256
- k = self.conv_k(c)
257
- v = self.conv_v(c)
258
-
259
- x, self.attn = self.attention(q, k, v, mask=attn_mask)
260
-
261
- x = self.conv_o(x)
262
- return x
263
-
264
- def attention(self, query, key, value, mask=None):
265
- b, d, t_s, t_t = (*key.size(), query.size(2))
266
- query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
267
- key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
268
- value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
269
-
270
- scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
271
- if self.window_size is not None:
272
- assert (t_s == t_t), "(t_s == t_t)"
273
- key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
274
- rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
275
- scores_local = self._relative_position_to_absolute_position(rel_logits)
276
- scores = scores + scores_local
277
-
278
- if self.proximal_bias:
279
- assert t_s == t_t, "t_s == t_t"
280
- scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
281
-
282
- if mask is not None:
283
- scores = scores.masked_fill(mask == 0, -1e4)
284
- if self.block_length is not None:
285
- assert (t_s == t_t), "(t_s == t_t)"
286
-
287
- block_mask = (torch.ones_like(scores).triu(-self.block_length).tril(self.block_length))
288
- scores = scores.masked_fill(block_mask == 0, -1e4)
289
-
290
- p_attn = torch.nn.functional.softmax(scores, dim=-1)
291
- p_attn = self.drop(p_attn)
292
- output = torch.matmul(p_attn, value)
293
-
294
- if self.window_size is not None:
295
- relative_weights = self._absolute_position_to_relative_position(p_attn)
296
- value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
297
- output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
298
-
299
- output = (output.transpose(2, 3).contiguous().view(b, d, t_t))
300
- return output, p_attn
301
-
302
- def _matmul_with_relative_values(self, x, y):
303
- ret = torch.matmul(x, y.unsqueeze(0))
304
- return ret
305
-
306
- def _matmul_with_relative_keys(self, x, y):
307
- ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
308
- return ret
309
-
310
- def _get_relative_embeddings(self, relative_embeddings, length):
311
- pad_length = max(length - (self.window_size + 1), 0)
312
- slice_start_position = max((self.window_size + 1) - length, 0)
313
- slice_end_position = slice_start_position + 2 * length - 1
314
-
315
- if pad_length > 0: padded_relative_embeddings = torch.nn.functional.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
316
- else: padded_relative_embeddings = relative_embeddings
317
-
318
- used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
319
- return used_relative_embeddings
320
-
321
- def _relative_position_to_absolute_position(self, x):
322
- batch, heads, length, _ = x.size()
323
-
324
- x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
325
-
326
- x_flat = x.view([batch, heads, length * 2 * length])
327
- x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
328
-
329
- x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
330
-
331
- return x_final
332
-
333
- def _absolute_position_to_relative_position(self, x):
334
- batch, heads, length, _ = x.size()
335
- x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
336
- x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
337
- x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
338
- x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
339
- return x_final
340
-
341
- def _attention_bias_proximal(self, length):
342
- r = torch.arange(length, dtype=torch.float32)
343
- diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
344
- return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
345
-
346
-
347
- class FFN(torch.nn.Module):
348
- def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False):
349
- super().__init__()
350
-
351
- self.in_channels = in_channels
352
- self.out_channels = out_channels
353
- self.filter_channels = filter_channels
354
- self.kernel_size = kernel_size
355
- self.p_dropout = p_dropout
356
- self.activation = activation
357
- self.causal = causal
358
-
359
- if causal: self.padding = self._causal_padding
360
- else: self.padding = self._same_padding
361
-
362
- self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size)
363
- self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size)
364
- self.drop = torch.nn.Dropout(p_dropout)
365
-
366
- def forward(self, x, x_mask):
367
- x = self.conv_1(self.padding(x * x_mask))
368
-
369
- if self.activation == "gelu": x = x * torch.sigmoid(1.702 * x)
370
- else: x = torch.relu(x)
371
-
372
- x = self.drop(x)
373
- x = self.conv_2(self.padding(x * x_mask))
374
- return x * x_mask
375
-
376
- def _causal_padding(self, x):
377
- if self.kernel_size == 1: return x
378
-
379
- pad_l = self.kernel_size - 1
380
- pad_r = 0
381
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
382
- x = torch.nn.functional.pad(x, convert_pad_shape(padding))
383
- return x
384
-
385
- def _same_padding(self, x):
386
- if self.kernel_size == 1: return x
387
-
388
- pad_l = (self.kernel_size - 1) // 2
389
- pad_r = self.kernel_size // 2
390
-
391
- padding = [[0, 0], [0, 0], [pad_l, pad_r]]
392
- x = torch.nn.functional.pad(x, convert_pad_shape(padding))
393
-
394
- return x
395
-
396
- class Encoder(torch.nn.Module):
397
- def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, **kwargs):
398
- super().__init__()
399
- self.hidden_channels = hidden_channels
400
- self.filter_channels = filter_channels
401
- self.n_heads = n_heads
402
- self.n_layers = n_layers
403
- self.kernel_size = kernel_size
404
- self.p_dropout = p_dropout
405
- self.window_size = window_size
406
-
407
- self.drop = torch.nn.Dropout(p_dropout)
408
- self.attn_layers = torch.nn.ModuleList()
409
- self.norm_layers_1 = torch.nn.ModuleList()
410
- self.ffn_layers = torch.nn.ModuleList()
411
- self.norm_layers_2 = torch.nn.ModuleList()
412
-
413
- for _ in range(self.n_layers):
414
- self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
415
- self.norm_layers_1.append(LayerNorm(hidden_channels))
416
- self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
417
- self.norm_layers_2.append(LayerNorm(hidden_channels))
418
-
419
- def forward(self, x, x_mask):
420
- attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
421
- x = x * x_mask
422
-
423
- for i in range(self.n_layers):
424
- y = self.attn_layers[i](x, x, attn_mask)
425
- y = self.drop(y)
426
- x = self.norm_layers_1[i](x + y)
427
-
428
- y = self.ffn_layers[i](x, x_mask)
429
- y = self.drop(y)
430
- x = self.norm_layers_2[i](x + y)
431
-
432
- x = x * x_mask
433
- return x
434
-
435
-
436
- class TextEncoder(torch.nn.Module):
437
- def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, embedding_dim, f0=True):
438
- super(TextEncoder, self).__init__()
439
- self.out_channels = out_channels
440
- self.hidden_channels = hidden_channels
441
- self.filter_channels = filter_channels
442
- self.n_heads = n_heads
443
- self.n_layers = n_layers
444
- self.kernel_size = kernel_size
445
- self.p_dropout = float(p_dropout)
446
- self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
447
- self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True)
448
-
449
- if f0: self.emb_pitch = torch.nn.Embedding(256, hidden_channels)
450
-
451
- self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout))
452
- self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
453
-
454
- def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor):
455
- if pitch is None: x = self.emb_phone(phone)
456
- else: x = self.emb_phone(phone) + self.emb_pitch(pitch)
457
-
458
- x = x * math.sqrt(self.hidden_channels)
459
- x = self.lrelu(x)
460
- x = torch.transpose(x, 1, -1)
461
- x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
462
- x = self.encoder(x * x_mask, x_mask)
463
- stats = self.proj(x) * x_mask
464
-
465
- m, logs = torch.split(stats, self.out_channels, dim=1)
466
- return m, logs, x_mask
467
-
468
-
469
- class PosteriorEncoder(torch.nn.Module):
470
- def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
471
- super(PosteriorEncoder, self).__init__()
472
- self.in_channels = in_channels
473
- self.out_channels = out_channels
474
- self.hidden_channels = hidden_channels
475
- self.kernel_size = kernel_size
476
- self.dilation_rate = dilation_rate
477
- self.n_layers = n_layers
478
- self.gin_channels = gin_channels
479
-
480
- self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1)
481
- self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
482
- self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
483
-
484
- def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None):
485
- x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
486
- x = self.pre(x) * x_mask
487
- x = self.enc(x, x_mask, g=g)
488
- stats = self.proj(x) * x_mask
489
- m, logs = torch.split(stats, self.out_channels, dim=1)
490
- z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
491
- return z, m, logs, x_mask
492
-
493
- def remove_weight_norm(self):
494
- self.enc.remove_weight_norm()
495
-
496
- def __prepare_scriptable__(self):
497
- for hook in self.enc._forward_pre_hooks.values():
498
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.enc)
499
-
500
- return self
501
-
502
- class Synthesizer(torch.nn.Module):
503
- def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, **kwargs):
504
- super(Synthesizer, self).__init__()
505
- self.spec_channels = spec_channels
506
- self.inter_channels = inter_channels
507
- self.hidden_channels = hidden_channels
508
- self.filter_channels = filter_channels
509
- self.n_heads = n_heads
510
- self.n_layers = n_layers
511
- self.kernel_size = kernel_size
512
- self.p_dropout = float(p_dropout)
513
- self.resblock = resblock
514
- self.resblock_kernel_sizes = resblock_kernel_sizes
515
- self.resblock_dilation_sizes = resblock_dilation_sizes
516
- self.upsample_rates = upsample_rates
517
- self.upsample_initial_channel = upsample_initial_channel
518
- self.upsample_kernel_sizes = upsample_kernel_sizes
519
- self.segment_size = segment_size
520
- self.gin_channels = gin_channels
521
- self.spk_embed_dim = spk_embed_dim
522
- self.use_f0 = use_f0
523
-
524
- self.enc_p = TextEncoder(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), text_enc_hidden_dim, f0=use_f0)
525
-
526
- if use_f0: self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs["is_half"])
527
- else: self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
528
-
529
- self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
530
- self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
531
- self.emb_g = torch.nn.Embedding(self.spk_embed_dim, gin_channels)
532
-
533
- def remove_weight_norm(self):
534
- self.dec.remove_weight_norm()
535
- self.flow.remove_weight_norm()
536
- self.enc_q.remove_weight_norm()
537
-
538
- def __prepare_scriptable__(self):
539
- for hook in self.dec._forward_pre_hooks.values():
540
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.dec)
541
-
542
- for hook in self.flow._forward_pre_hooks.values():
543
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.flow)
544
-
545
- if hasattr(self, "enc_q"):
546
- for hook in self.enc_q._forward_pre_hooks.values():
547
- if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.enc_q)
548
-
549
- return self
550
-
551
- @torch.jit.ignore
552
- def forward(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: Optional[torch.Tensor] = None, pitchf: Optional[torch.Tensor] = None, y: torch.Tensor = None, y_lengths: torch.Tensor = None, ds: Optional[torch.Tensor] = None):
553
- g = self.emb_g(ds).unsqueeze(-1)
554
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
555
-
556
- if y is not None:
557
- z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
558
- z_p = self.flow(z, y_mask, g=g)
559
- z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
560
-
561
- if self.use_f0:
562
- pitchf = slice_segments(pitchf, ids_slice, self.segment_size, 2)
563
- o = self.dec(z_slice, pitchf, g=g)
564
- else: o = self.dec(z_slice, g=g)
565
-
566
- return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
567
- else: return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)
568
-
569
- @torch.jit.export
570
- def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: Optional[torch.Tensor] = None, nsff0: Optional[torch.Tensor] = None, sid: torch.Tensor = None, rate: Optional[torch.Tensor] = None):
571
- g = self.emb_g(sid).unsqueeze(-1)
572
- m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
573
- z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
574
-
575
- if rate is not None:
576
- assert isinstance(rate, torch.Tensor)
577
- head = int(z_p.shape[2] * (1.0 - rate.item()))
578
- z_p = z_p[:, :, head:]
579
- x_mask = x_mask[:, :, head:]
580
-
581
- if self.use_f0: nsff0 = nsff0[:, head:]
582
-
583
- if self.use_f0:
584
- z = self.flow(z_p, x_mask, g=g, reverse=True)
585
- o = self.dec(z * x_mask, nsff0, g=g)
586
- else:
587
- z = self.flow(z_p, x_mask, g=g, reverse=True)
588
- o = self.dec(z * x_mask, g=g)
589
-
590
- return o, x_mask, (z, z_p, m_p, logs_p)
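For orientation, the sketch below shows one way `Synthesizer.infer` above could be exercised end to end. It is illustrative only: the module path, the 40 kHz hyperparameters, and the tensor shapes are assumptions chosen to satisfy the constructor signature shown here, not values read from this repository's config files.

import torch
from main.library.algorithm.synthesizers import Synthesizer  # assumed module path

# Hypothetical 40 kHz, f0-conditioned setup; every number here is an assumption.
net_g = Synthesizer(
    spec_channels=1025, segment_size=32, inter_channels=192, hidden_channels=192,
    filter_channels=768, n_heads=2, n_layers=6, kernel_size=3, p_dropout=0.0,
    resblock="1", resblock_kernel_sizes=[3, 7, 11],
    resblock_dilation_sizes=[[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    upsample_rates=[10, 10, 2, 2], upsample_initial_channel=512,
    upsample_kernel_sizes=[16, 16, 4, 4], spk_embed_dim=109, gin_channels=256,
    sr=40000, use_f0=True, text_enc_hidden_dim=768, is_half=False,
).eval()

phone = torch.randn(1, 200, 768)          # (batch, frames, text-encoder features)
phone_lengths = torch.tensor([200])
pitch = torch.randint(1, 255, (1, 200))   # coarse pitch indices for emb_pitch
nsff0 = torch.full((1, 200), 150.0)       # f0 contour in Hz for the NSF decoder
sid = torch.tensor([0])                   # speaker id looked up in emb_g

with torch.no_grad():
    audio, x_mask, _ = net_g.infer(phone, phone_lengths, pitch, nsff0, sid)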
 
main/library/architectures/demucs_separator.py DELETED
@@ -1,340 +0,0 @@
1
- import os
2
- import sys
3
- import yaml
4
- import torch
5
-
6
- import numpy as np
7
- import typing as tp
8
-
9
- from pathlib import Path
10
- from hashlib import sha256
11
-
12
- now_dir = os.getcwd()
13
- sys.path.append(now_dir)
14
-
15
- from main.configs.config import Config
16
- from main.library.uvr5_separator import spec_utils
17
- from main.library.uvr5_separator.demucs.hdemucs import HDemucs
18
- from main.library.uvr5_separator.demucs.states import load_model
19
- from main.library.uvr5_separator.demucs.apply import BagOfModels, Model
20
- from main.library.uvr5_separator.common_separator import CommonSeparator
21
- from main.library.uvr5_separator.demucs.apply import apply_model, demucs_segments
22
-
23
-
24
- translations = Config().translations
25
-
26
- DEMUCS_4_SOURCE = ["drums", "bass", "other", "vocals"]
27
-
28
- DEMUCS_2_SOURCE_MAPPER = {
29
- CommonSeparator.INST_STEM: 0,
30
- CommonSeparator.VOCAL_STEM: 1
31
- }
32
-
33
- DEMUCS_4_SOURCE_MAPPER = {
34
- CommonSeparator.BASS_STEM: 0,
35
- CommonSeparator.DRUM_STEM: 1,
36
- CommonSeparator.OTHER_STEM: 2,
37
- CommonSeparator.VOCAL_STEM: 3
38
- }
39
-
40
- DEMUCS_6_SOURCE_MAPPER = {
41
- CommonSeparator.BASS_STEM: 0,
42
- CommonSeparator.DRUM_STEM: 1,
43
- CommonSeparator.OTHER_STEM: 2,
44
- CommonSeparator.VOCAL_STEM: 3,
45
- CommonSeparator.GUITAR_STEM: 4,
46
- CommonSeparator.PIANO_STEM: 5,
47
- }
48
-
49
-
50
- REMOTE_ROOT = Path(__file__).parent / "remote"
51
-
52
- PRETRAINED_MODELS = {
53
- "demucs": "e07c671f",
54
- "demucs48_hq": "28a1282c",
55
- "demucs_extra": "3646af93",
56
- "demucs_quantized": "07afea75",
57
- "tasnet": "beb46fac",
58
- "tasnet_extra": "df3777b2",
59
- "demucs_unittest": "09ebc15f",
60
- }
61
-
62
-
63
- sys.path.insert(0, os.path.join(os.getcwd(), "main", "library", "uvr5_separator"))
64
-
65
- AnyModel = tp.Union[Model, BagOfModels]
66
-
67
-
68
- class DemucsSeparator(CommonSeparator):
69
- def __init__(self, common_config, arch_config):
70
- super().__init__(config=common_config)
71
-
72
- self.segment_size = arch_config.get("segment_size", "Default")
73
- self.shifts = arch_config.get("shifts", 2)
74
- self.overlap = arch_config.get("overlap", 0.25)
75
- self.segments_enabled = arch_config.get("segments_enabled", True)
76
-
77
- self.logger.debug(translations["demucs_info"].format(segment_size=self.segment_size, segments_enabled=self.segments_enabled))
78
- self.logger.debug(translations["demucs_info_2"].format(shifts=self.shifts, overlap=self.overlap))
79
-
80
- self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
81
-
82
- self.audio_file_path = None
83
- self.audio_file_base = None
84
- self.demucs_model_instance = None
85
-
86
- self.logger.info(translations["start_demucs"])
87
-
88
- def separate(self, audio_file_path):
89
- self.logger.debug(translations["start_separator"])
90
-
91
- source = None
92
- stem_source = None
93
-
94
- inst_source = {}
95
-
96
- self.audio_file_path = audio_file_path
97
- self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
98
-
99
- self.logger.debug(translations["prepare_mix"])
100
- mix = self.prepare_mix(self.audio_file_path)
101
-
102
- self.logger.debug(translations["demix"].format(shape=mix.shape))
103
-
104
- self.logger.debug(translations["cancel_mix"])
105
-
106
- self.demucs_model_instance = HDemucs(sources=DEMUCS_4_SOURCE)
107
- self.demucs_model_instance = get_demucs_model(name=os.path.splitext(os.path.basename(self.model_path))[0], repo=Path(os.path.dirname(self.model_path)))
108
- self.demucs_model_instance = demucs_segments(self.segment_size, self.demucs_model_instance)
109
- self.demucs_model_instance.to(self.torch_device)
110
- self.demucs_model_instance.eval()
111
-
112
- self.logger.debug(translations["model_review"])
113
-
114
- source = self.demix_demucs(mix)
115
-
116
- del self.demucs_model_instance
117
- self.clear_gpu_cache()
118
- self.logger.debug(translations["del_gpu_cache_after_demix"])
119
-
120
- output_files = []
121
- self.logger.debug(translations["process_output_file"])
122
-
123
- if isinstance(inst_source, np.ndarray):
124
- self.logger.debug(translations["process_ver"])
125
- source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]], source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]])
126
- inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]] = source_reshape
127
- source = inst_source
128
-
129
- if isinstance(source, np.ndarray):
130
- source_length = len(source)
131
- self.logger.debug(translations["source_length"].format(source_length=source_length))
132
-
133
- match source_length:
134
- case 2:
135
- self.logger.debug(translations["set_map"].format(part="2"))
136
- self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
137
- case 6:
138
- self.logger.debug(translations["set_map"].format(part="6"))
139
- self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER
140
- case _:
141
- self.logger.debug(translations["set_map"].format(part="2"))
142
- self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
143
-
144
- self.logger.debug(translations["process_all_part"])
145
-
146
- for stem_name, stem_value in self.demucs_source_map.items():
147
- if self.output_single_stem is not None:
148
- if stem_name.lower() != self.output_single_stem.lower():
149
- self.logger.debug(translations["skip_part"].format(stem_name=stem_name, output_single_stem=self.output_single_stem))
150
- continue
151
-
152
- stem_path = os.path.join(f"{self.audio_file_base}_({stem_name})_{self.model_name}.{self.output_format.lower()}")
153
- stem_source = source[stem_value].T
154
-
155
- self.final_process(stem_path, stem_source, stem_name)
156
- output_files.append(stem_path)
157
-
158
- return output_files
159
-
160
- def demix_demucs(self, mix):
161
- self.logger.debug(translations["starting_demix_demucs"])
162
-
163
- processed = {}
164
- mix = torch.tensor(mix, dtype=torch.float32)
165
- ref = mix.mean(0)
166
- mix = (mix - ref.mean()) / ref.std()
167
- mix_infer = mix
168
-
169
- with torch.no_grad():
170
- self.logger.debug(translations["model_infer"])
171
- sources = apply_model(model=self.demucs_model_instance, mix=mix_infer[None], shifts=self.shifts, split=self.segments_enabled, overlap=self.overlap, static_shifts=1 if self.shifts == 0 else self.shifts, set_progress_bar=None, device=self.torch_device, progress=True)[0]
172
-
173
- sources = (sources * ref.std() + ref.mean()).cpu().numpy()
174
- sources[[0, 1]] = sources[[1, 0]]
175
-
176
- processed[mix] = sources[:, :, 0:None].copy()
177
-
178
- sources = list(processed.values())
179
- sources = [s[:, :, 0:None] for s in sources]
180
- sources = np.concatenate(sources, axis=-1)
181
- return sources
182
-
183
-
184
- class ModelOnlyRepo:
185
- def has_model(self, sig: str) -> bool:
186
- raise NotImplementedError()
187
-
188
- def get_model(self, sig: str) -> Model:
189
- raise NotImplementedError()
190
-
191
-
192
- class RemoteRepo(ModelOnlyRepo):
193
- def __init__(self, models: tp.Dict[str, str]):
194
- self._models = models
195
-
196
- def has_model(self, sig: str) -> bool:
197
- return sig in self._models
198
-
199
- def get_model(self, sig: str) -> Model:
200
- try:
201
- url = self._models[sig]
202
- except KeyError:
203
- raise RuntimeError(translations["not_found_model_signature"].format(sig=sig))
204
-
205
- pkg = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True)
206
- return load_model(pkg)
207
-
208
-
209
- class LocalRepo(ModelOnlyRepo):
210
- def __init__(self, root: Path):
211
- self.root = root
212
- self.scan()
213
-
214
- def scan(self):
215
- self._models = {}
216
- self._checksums = {}
217
-
218
- for file in self.root.iterdir():
219
- if file.suffix == ".th":
220
- if "-" in file.stem:
221
- xp_sig, checksum = file.stem.split("-")
222
- self._checksums[xp_sig] = checksum
223
- else: xp_sig = file.stem
224
-
225
- if xp_sig in self._models: raise RuntimeError(translations["del_all_but_one"].format(xp_sig=xp_sig))
226
-
227
- self._models[xp_sig] = file
228
-
229
- def has_model(self, sig: str) -> bool:
230
- return sig in self._models
231
-
232
- def get_model(self, sig: str) -> Model:
233
- try:
234
- file = self._models[sig]
235
- except KeyError:
236
- raise RuntimeError(translations["not_found_model_signature"].format(sig=sig))
237
-
238
- if sig in self._checksums: check_checksum(file, self._checksums[sig])
239
-
240
- return load_model(file)
241
-
242
-
243
- class BagOnlyRepo:
244
- def __init__(self, root: Path, model_repo: ModelOnlyRepo):
245
- self.root = root
246
- self.model_repo = model_repo
247
- self.scan()
248
-
249
- def scan(self):
250
- self._bags = {}
251
-
252
- for file in self.root.iterdir():
253
- if file.suffix == ".yaml": self._bags[file.stem] = file
254
-
255
- def has_model(self, name: str) -> bool:
256
- return name in self._bags
257
-
258
- def get_model(self, name: str) -> BagOfModels:
259
- try:
260
- yaml_file = self._bags[name]
261
- except KeyError:
262
- raise RuntimeError(translations["name_not_pretrained"].format(name=name))
263
-
264
- bag = yaml.safe_load(open(yaml_file))
265
- signatures = bag["models"]
266
- models = [self.model_repo.get_model(sig) for sig in signatures]
267
-
268
- weights = bag.get("weights")
269
- segment = bag.get("segment")
270
-
271
- return BagOfModels(models, weights, segment)
272
-
273
-
274
- class AnyModelRepo:
275
- def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
276
- self.model_repo = model_repo
277
- self.bag_repo = bag_repo
278
-
279
- def has_model(self, name_or_sig: str) -> bool:
280
- return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)
281
-
282
- def get_model(self, name_or_sig: str) -> AnyModel:
283
- if self.model_repo.has_model(name_or_sig): return self.model_repo.get_model(name_or_sig)
284
- else: return self.bag_repo.get_model(name_or_sig)
285
-
286
-
287
- def check_checksum(path: Path, checksum: str):
288
- sha = sha256()
289
-
290
- with open(path, "rb") as file:
291
- while 1:
292
- buf = file.read(2**20)
293
- if not buf: break
294
-
295
- sha.update(buf)
296
-
297
- actual_checksum = sha.hexdigest()[: len(checksum)]
298
-
299
- if actual_checksum != checksum: raise RuntimeError(translations["invalid_checksum"].format(path=path, checksum=checksum, actual_checksum=actual_checksum))
300
-
301
-
302
- def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
303
- root: str = ""
304
- models: tp.Dict[str, str] = {}
305
-
306
- for line in remote_file_list.read_text().split("\n"):
307
- line = line.strip()
308
-
309
- if line.startswith("#"): continue
310
- elif line.startswith("root:"): root = line.split(":", 1)[1].strip()
311
- else:
312
- sig = line.split("-", 1)[0]
313
- assert sig not in models
314
-
315
- models[sig] = "https://dl.fbaipublicfiles.com/demucs/mdx_final/" + root + line
316
-
317
- return models
318
-
319
-
320
- def get_demucs_model(name: str, repo: tp.Optional[Path] = None):
321
- if name == "demucs_unittest": return HDemucs(channels=4, sources=DEMUCS_4_SOURCE)
322
-
323
- model_repo: ModelOnlyRepo
324
-
325
- if repo is None:
326
- models = _parse_remote_files(REMOTE_ROOT / "files.txt")
327
- model_repo = RemoteRepo(models)
328
- bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
329
- else:
330
- if not repo.is_dir(): print(translations["repo_must_be_folder"].format(repo=repo))
331
-
332
- model_repo = LocalRepo(repo)
333
- bag_repo = BagOnlyRepo(repo, model_repo)
334
-
335
- any_repo = AnyModelRepo(model_repo, bag_repo)
336
-
337
- model = any_repo.get_model(name)
338
- model.eval()
339
-
340
- return model
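A brief, hedged usage sketch for `get_demucs_model` follows; the checkpoint name and folder are placeholders, not files known to ship with this project.

from pathlib import Path
from main.library.architectures.demucs_separator import get_demucs_model

# "htdemucs_ft" is a hypothetical stem: LocalRepo resolves "<stem>.th" checkpoints,
# while BagOnlyRepo resolves "<stem>.yaml" bags that list several model signatures.
model = get_demucs_model(name="htdemucs_ft", repo=Path("assets/models/demucs"))
print(model.sources)  # typically something like ['drums', 'bass', 'other', 'vocals']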
 
main/library/architectures/mdx_separator.py DELETED
@@ -1,370 +0,0 @@
1
- import os
2
- import sys
3
- import onnx
4
- import torch
5
- import platform
6
- import onnx2torch
7
-
8
- import numpy as np
9
- import onnxruntime as ort
10
-
11
- from tqdm import tqdm
12
-
13
- now_dir = os.getcwd()
14
- sys.path.append(now_dir)
15
-
16
- from main.configs.config import Config
17
- from main.library.uvr5_separator import spec_utils
18
- from main.library.uvr5_separator.common_separator import CommonSeparator
19
-
20
-
21
- translations = Config().translations
22
-
23
- class MDXSeparator(CommonSeparator):
24
- def __init__(self, common_config, arch_config):
25
- super().__init__(config=common_config)
26
-
27
- self.segment_size = arch_config.get("segment_size")
28
- self.overlap = arch_config.get("overlap")
29
- self.batch_size = arch_config.get("batch_size", 1)
30
- self.hop_length = arch_config.get("hop_length")
31
- self.enable_denoise = arch_config.get("enable_denoise")
32
- self.logger.debug(translations["mdx_info"].format(batch_size=self.batch_size, segment_size=self.segment_size))
33
- self.logger.debug(translations["mdx_info_2"].format(overlap=self.overlap, hop_length=self.hop_length, enable_denoise=self.enable_denoise))
34
- self.compensate = self.model_data["compensate"]
35
- self.dim_f = self.model_data["mdx_dim_f_set"]
36
- self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
37
- self.n_fft = self.model_data["mdx_n_fft_scale_set"]
38
- self.config_yaml = self.model_data.get("config_yaml", None)
39
- self.logger.debug(f"{translations['mdx_info_3']}: compensate = {self.compensate}, dim_f = {self.dim_f}, dim_t = {self.dim_t}, n_fft = {self.n_fft}")
40
- self.logger.debug(f"{translations['mdx_info_3']}: config_yaml = {self.config_yaml}")
41
- self.load_model()
42
- self.n_bins = 0
43
- self.trim = 0
44
- self.chunk_size = 0
45
- self.gen_size = 0
46
- self.stft = None
47
- self.primary_source = None
48
- self.secondary_source = None
49
- self.audio_file_path = None
50
- self.audio_file_base = None
51
-
52
-
53
- def load_model(self):
54
- self.logger.debug(translations["load_model_onnx"])
55
-
56
- if self.segment_size == self.dim_t:
57
- ort_session_options = ort.SessionOptions()
58
-
59
- if self.log_level > 10: ort_session_options.log_severity_level = 3
60
- else: ort_session_options.log_severity_level = 0
61
-
62
- ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
63
- self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
64
- self.logger.debug(translations["load_model_onnx_success"])
65
- else:
66
- if platform.system() == 'Windows':
67
- onnx_model = onnx.load(self.model_path)
68
- self.model_run = onnx2torch.convert(onnx_model)
69
- else: self.model_run = onnx2torch.convert(self.model_path)
70
-
71
- self.model_run.to(self.torch_device).eval()
72
- self.logger.warning(translations["onnx_to_pytorch"])
73
-
74
- def separate(self, audio_file_path):
75
- self.audio_file_path = audio_file_path
76
- self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
77
-
78
- self.logger.debug(translations["mix"].format(audio_file_path=self.audio_file_path))
79
- mix = self.prepare_mix(self.audio_file_path)
80
-
81
- self.logger.debug(translations["normalization_demix"])
82
- mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold)
83
-
84
- source = self.demix(mix)
85
- self.logger.debug(translations["mix_success"])
86
-
87
- output_files = []
88
- self.logger.debug(translations["process_output_file"])
89
-
90
- if not isinstance(self.primary_source, np.ndarray):
91
- self.logger.debug(translations["primary_source"])
92
- self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold).T
93
- if not isinstance(self.secondary_source, np.ndarray):
94
- self.logger.debug(translations["secondary_source"])
95
- raw_mix = self.demix(mix, is_match_mix=True)
96
-
97
- if self.invert_using_spec:
98
- self.logger.debug(translations["invert_using_spec"])
99
- self.secondary_source = spec_utils.invert_stem(raw_mix, source)
100
- else:
101
- self.logger.debug(translations["invert_using_spec_2"])
102
- self.secondary_source = mix.T - source.T
103
-
104
- if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
105
- self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
106
- self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.secondary_stem_name, stem_output_path=self.secondary_stem_output_path))
107
- self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
108
- output_files.append(self.secondary_stem_output_path)
109
-
110
- if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
111
- self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")
112
-
113
- if not isinstance(self.primary_source, np.ndarray): self.primary_source = source.T
114
-
115
- self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.primary_stem_name, stem_output_path=self.primary_stem_output_path))
116
- self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
117
- output_files.append(self.primary_stem_output_path)
118
-
119
- return output_files
120
-
121
- def initialize_model_settings(self):
122
- self.logger.debug(translations["starting_model"])
123
-
124
- self.n_bins = self.n_fft // 2 + 1
125
- self.trim = self.n_fft // 2
126
-
127
- self.chunk_size = self.hop_length * (self.segment_size - 1)
128
- self.gen_size = self.chunk_size - 2 * self.trim
129
-
130
- self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)
131
-
132
- self.logger.debug(f"{translations['input_info']}: n_fft = {self.n_fft} hop_length = {self.hop_length} dim_f = {self.dim_f}")
133
- self.logger.debug(f"{translations['model_settings']}: n_bins = {self.n_bins}, Trim = {self.trim}, chunk_size = {self.chunk_size}, gen_size = {self.gen_size}")
134
-
135
- def initialize_mix(self, mix, is_ckpt=False):
136
- self.logger.debug(translations["initialize_mix"].format(is_ckpt=is_ckpt, shape=mix.shape))
137
-
138
- if mix.shape[0] != 2:
139
- error_message = translations["!=2"].format(shape=mix.shape[0])
140
- self.logger.error(error_message)
141
- raise ValueError(error_message)
142
-
143
- if is_ckpt:
144
- self.logger.debug(translations["process_check"])
145
- pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
146
- self.logger.debug(f"{translations['cache']}: {pad}")
147
-
148
- mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
149
-
150
- num_chunks = mixture.shape[-1] // self.gen_size
151
- self.logger.debug(translations["shape"].format(shape=mixture.shape, num_chunks=num_chunks))
152
-
153
- mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
154
- else:
155
- self.logger.debug(translations["process_no_check"])
156
- mix_waves = []
157
- n_sample = mix.shape[1]
158
-
159
- pad = self.gen_size - n_sample % self.gen_size
160
- self.logger.debug(translations["n_sample_or_pad"].format(n_sample=n_sample, pad=pad))
161
-
162
- mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
163
- self.logger.debug(f"{translations['shape_2']}: {mix_p.shape}")
164
-
165
- i = 0
166
- while i < n_sample + pad:
167
- waves = np.array(mix_p[:, i : i + self.chunk_size])
168
- mix_waves.append(waves)
169
-
170
- self.logger.debug(translations["process_part"].format(mix_waves=len(mix_waves), i=i, ii=i + self.chunk_size))
171
- i += self.gen_size
172
-
173
- mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
174
- self.logger.debug(translations["mix_waves_to_tensor"].format(shape=mix_waves_tensor.shape))
175
-
176
- return mix_waves_tensor, pad
177
-
178
- def demix(self, mix, is_match_mix=False):
179
- self.logger.debug(f"{translations['demix_is_match_mix']}: {is_match_mix}...")
180
- self.initialize_model_settings()
181
-
182
- org_mix = mix
183
- self.logger.debug(f"{translations['mix_shape']}: {org_mix.shape}")
184
-
185
- tar_waves_ = []
186
-
187
- if is_match_mix:
188
- chunk_size = self.hop_length * (self.segment_size - 1)
189
- overlap = 0.02
190
- self.logger.debug(translations["chunk_size_or_overlap"].format(chunk_size=chunk_size, overlap=overlap))
191
- else:
192
- chunk_size = self.chunk_size
193
- overlap = self.overlap
194
- self.logger.debug(translations["chunk_size_or_overlap_standard"].format(chunk_size=chunk_size, overlap=overlap))
195
-
196
-
197
- gen_size = chunk_size - 2 * self.trim
198
- self.logger.debug(f"{translations['calc_size']}: {gen_size}")
199
-
200
-
201
- pad = gen_size + self.trim - ((mix.shape[-1]) % gen_size)
202
-
203
- mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
204
- self.logger.debug(f"{translations['mix_cache']}: {mixture.shape}")
205
-
206
- step = int((1 - overlap) * chunk_size)
207
- self.logger.debug(translations["step_or_overlap"].format(step=step, overlap=overlap))
208
-
209
- result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
210
- divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
211
-
212
- total = 0
213
- total_chunks = (mixture.shape[-1] + step - 1) // step
214
- self.logger.debug(f"{translations['all_process_part']}: {total_chunks}")
215
-
216
- for i in tqdm(range(0, mixture.shape[-1], step)):
217
- total += 1
218
- start = i
219
- end = min(i + chunk_size, mixture.shape[-1])
220
- self.logger.debug(translations["process_part_2"].format(total=total, total_chunks=total_chunks, start=start, end=end))
221
-
222
- chunk_size_actual = end - start
223
- window = None
224
-
225
- if overlap != 0:
226
- window = np.hanning(chunk_size_actual)
227
- window = np.tile(window[None, None, :], (1, 2, 1))
228
- self.logger.debug(translations["window"])
229
-
230
- mix_part_ = mixture[:, start:end]
231
-
232
- if end != i + chunk_size:
233
- pad_size = (i + chunk_size) - end
234
- mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)
235
-
236
- mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device)
237
-
238
- mix_waves = mix_part.split(self.batch_size)
239
- total_batches = len(mix_waves)
240
- self.logger.debug(f"{translations['mix_or_batch']}: {total_batches}")
241
-
242
- with torch.no_grad():
243
- batches_processed = 0
244
- for mix_wave in mix_waves:
245
- batches_processed += 1
246
- self.logger.debug(f"{translations['mix_wave']} {batches_processed}/{total_batches}")
247
-
248
- tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)
249
-
250
- if window is not None:
251
- tar_waves[..., :chunk_size_actual] *= window
252
- divider[..., start:end] += window
253
- else: divider[..., start:end] += 1
254
-
255
- result[..., start:end] += tar_waves[..., : end - start]
256
-
257
-
258
- self.logger.debug(translations["normalization_2"])
259
- tar_waves = result / divider
260
- tar_waves_.append(tar_waves)
261
-
262
- tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim : -self.trim]
263
- tar_waves = np.concatenate(tar_waves_, axis=-1)[:, : mix.shape[-1]]
264
-
265
- source = tar_waves[:, 0:None]
266
- self.logger.debug(f"{translations['tar_waves']}: {tar_waves.shape}")
267
-
268
- if not is_match_mix:
269
- source *= self.compensate
270
- self.logger.debug(translations["mix_match"])
271
-
272
- self.logger.debug(translations["mix_success"])
273
- return source
274
-
275
- def run_model(self, mix, is_match_mix=False):
276
- spek = self.stft(mix.to(self.torch_device))
277
- self.logger.debug(translations["stft_2"].format(shape=spek.shape))
278
-
279
- spek[:, :, :3, :] *= 0
280
-
281
- if is_match_mix:
282
- spec_pred = spek.cpu().numpy()
283
- self.logger.debug(translations["is_match_mix"])
284
- else:
285
- if self.enable_denoise:
286
- spec_pred_neg = self.model_run(-spek)
287
- spec_pred_pos = self.model_run(spek)
288
- spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
289
- self.logger.debug(translations["enable_denoise"])
290
- else:
291
- spec_pred = self.model_run(spek)
292
- self.logger.debug(translations["no_denoise"])
293
-
294
- result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
295
- self.logger.debug(f"{translations['stft']}: {result.shape}")
296
-
297
- return result
298
-
299
- class STFT:
300
- def __init__(self, logger, n_fft, hop_length, dim_f, device):
301
- self.logger = logger
302
- self.n_fft = n_fft
303
- self.hop_length = hop_length
304
- self.dim_f = dim_f
305
- self.device = device
306
- self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)
307
-
308
- def __call__(self, input_tensor):
309
- is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
310
-
311
- if is_non_standard_device: input_tensor = input_tensor.cpu()
312
-
313
- stft_window = self.hann_window.to(input_tensor.device)
314
-
315
- batch_dimensions = input_tensor.shape[:-2]
316
- channel_dim, time_dim = input_tensor.shape[-2:]
317
-
318
- reshaped_tensor = input_tensor.reshape([-1, time_dim])
319
- stft_output = torch.stft(reshaped_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True, return_complex=False)
320
-
321
- permuted_stft_output = stft_output.permute([0, 3, 1, 2])
322
-
323
- final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape([*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]])
324
-
325
- if is_non_standard_device: final_output = final_output.to(self.device)
326
-
327
- return final_output[..., : self.dim_f, :]
328
-
329
- def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
330
- freq_padding = torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)
331
- padded_tensor = torch.cat([input_tensor, freq_padding], -2)
332
-
333
- return padded_tensor
334
-
335
- def calculate_inverse_dimensions(self, input_tensor):
336
- batch_dimensions = input_tensor.shape[:-3]
337
- channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]
338
-
339
- num_freq_bins = self.n_fft // 2 + 1
340
-
341
- return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins
342
-
343
- def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
344
- reshaped_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim])
345
- flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim])
346
- permuted_tensor = flattened_tensor.permute([0, 2, 3, 1])
347
- complex_tensor = permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j
348
-
349
- return complex_tensor
350
-
351
- def inverse(self, input_tensor):
352
- is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]
353
-
354
- if is_non_standard_device: input_tensor = input_tensor.cpu()
355
-
356
- stft_window = self.hann_window.to(input_tensor.device)
357
-
358
- batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)
359
-
360
- padded_tensor = self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)
361
-
362
- complex_tensor = self.prepare_for_istft(padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim)
363
-
364
- istft_result = torch.istft(complex_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True)
365
-
366
- final_output = istft_result.reshape([*batch_dimensions, 2, -1])
367
-
368
- if is_non_standard_device: final_output = final_output.to(self.device)
369
-
370
- return final_output
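To make the stacked real/imaginary layout of the MDX `STFT` helper concrete, here is a small round-trip sketch; the FFT size, hop length, and `dim_f` are assumptions in the range MDX-Net models commonly use, not values taken from a specific model's metadata.

import logging
import torch
from main.library.architectures.mdx_separator import STFT

# Assumed settings; real values come from the model_data fields read in MDXSeparator.__init__.
stft = STFT(logging.getLogger(__name__), n_fft=6144, hop_length=1024, dim_f=2048, device="cpu")

wave = torch.randn(1, 2, 1024 * 255)   # (batch, stereo channels, samples)
spec = stft(wave)                       # (batch, 4, dim_f, frames): real/imag stacked per channel
restored = stft.inverse(spec)           # (batch, 2, samples), band-limited to the kept dim_f bins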
 
main/library/predictors/FCPE.py DELETED
@@ -1,600 +0,0 @@
1
- import os
2
- import math
3
- import torch
4
- import librosa
5
- import numpy as np
6
- import torch.nn as nn
7
- import soundfile as sf
8
- import torch.utils.data
9
- import torch.nn.functional as F
10
-
11
- from torch import nn
12
- from typing import Union
13
- from functools import partial
14
- from einops import rearrange, repeat
15
- from torchaudio.transforms import Resample
16
- from local_attention import LocalAttention
17
- from librosa.filters import mel as librosa_mel_fn
18
- from torch.nn.utils.parametrizations import weight_norm
19
-
20
- os.environ["LRU_CACHE_CAPACITY"] = "3"
21
-
22
- def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
23
- try:
24
- data, sample_rate = sf.read(full_path, always_2d=True)
25
- except Exception as e:
26
- print(f"{full_path}: {e}")
27
- if return_empty_on_exception: return [], target_sr or 48000  # sample_rate is unbound if sf.read raised
28
- else: raise
29
-
30
- data = data[:, 0] if len(data.shape) > 1 else data
31
- assert len(data) > 2
32
-
33
- max_mag = (-np.iinfo(data.dtype).min if np.issubdtype(data.dtype, np.integer) else max(np.amax(data), -np.amin(data)))
34
- max_mag = ((2**31) + 1 if max_mag > (2**15) else ((2**15) + 1 if max_mag > 1.01 else 1.0))
35
- data = torch.FloatTensor(data.astype(np.float32)) / max_mag
36
-
37
- if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception: return [], sample_rate or target_sr or 48000
38
- if target_sr is not None and sample_rate != target_sr:
39
- data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sample_rate, target_sr=target_sr))
40
- sample_rate = target_sr
41
-
42
- return data, sample_rate
43
-
44
- def dynamic_range_compression(x, C=1, clip_val=1e-5):
45
- return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
46
- def dynamic_range_decompression(x, C=1):
47
- return np.exp(x) / C
48
- def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
49
- return torch.log(torch.clamp(x, min=clip_val) * C)
50
- def dynamic_range_decompression_torch(x, C=1):
51
- return torch.exp(x) / C
52
-
53
- class STFT:
54
- def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
55
- self.target_sr = sr
56
- self.n_mels = n_mels
57
- self.n_fft = n_fft
58
- self.win_size = win_size
59
- self.hop_length = hop_length
60
- self.fmin = fmin
61
- self.fmax = fmax
62
- self.clip_val = clip_val
63
- self.mel_basis = {}
64
- self.hann_window = {}
65
-
66
- def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
67
- sample_rate = self.target_sr
68
- n_mels = self.n_mels
69
- n_fft = self.n_fft
70
- win_size = self.win_size
71
- hop_length = self.hop_length
72
- fmin = self.fmin
73
- fmax = self.fmax
74
- clip_val = self.clip_val
75
-
76
- factor = 2 ** (keyshift / 12)
77
- n_fft_new = int(np.round(n_fft * factor))
78
- win_size_new = int(np.round(win_size * factor))
79
- hop_length_new = int(np.round(hop_length * speed))
80
-
81
- mel_basis = self.mel_basis if not train else {}
82
- hann_window = self.hann_window if not train else {}
83
- mel_basis_key = str(fmax) + "_" + str(y.device)
84
-
85
- if mel_basis_key not in mel_basis:
86
- mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
87
- mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)
88
-
89
- keyshift_key = str(keyshift) + "_" + str(y.device)
90
-
91
- if keyshift_key not in hann_window: hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)
92
-
93
- pad_left = (win_size_new - hop_length_new) // 2
94
- pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
95
- mode = "reflect" if pad_right < y.size(-1) else "constant"
96
- y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode)
97
- y = y.squeeze(1)
98
-
99
- spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
100
- spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))
101
-
102
- if keyshift != 0:
103
- size = n_fft // 2 + 1
104
- resize = spec.size(1)
105
- spec = (F.pad(spec, (0, 0, 0, size - resize)) if resize < size else spec[:, :size, :])
106
- spec = spec * win_size / win_size_new
107
- spec = torch.matmul(mel_basis[mel_basis_key], spec)
108
- spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
109
- return spec
110
-
111
- def __call__(self, audiopath):
112
- audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
113
- spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
114
- return spect
115
-
116
- stft = STFT()
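The module-level `stft` instance above gives a quick path from an audio file to the log-mel features the FCPE model consumes. A small sketch, where the file path is a placeholder rather than a file in this repo:

# "voice.wav" is a placeholder; any file readable by soundfile works.
mel = stft("voice.wav")   # loads and resamples to 22050 Hz, returns a log-mel tensor
print(mel.shape)          # torch.Size([80, n_frames])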
117
-
118
- def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None):
119
- b, h, *_ = data.shape
120
-
121
- data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0
122
-
123
- ratio = projection_matrix.shape[0] ** -0.5
124
- projection = repeat(projection_matrix, "j d -> b h j d", b=b, h=h)
125
- projection = projection.type_as(data)
126
- data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), projection)
127
-
128
- diag_data = data**2
129
- diag_data = torch.sum(diag_data, dim=-1)
130
- diag_data = (diag_data / 2.0) * (data_normalizer**2)
131
- diag_data = diag_data.unsqueeze(dim=-1)
132
-
133
- if is_query: data_dash = ratio * (torch.exp(data_dash - diag_data - torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
134
- else: data_dash = ratio * (torch.exp(data_dash - diag_data + eps))
135
-
136
- return data_dash.type_as(data)
137
-
138
- def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
139
- unstructured_block = torch.randn((cols, cols), device=device)
140
- q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
141
- q, r = map(lambda t: t.to(device), (q, r))
142
-
143
- if qr_uniform_q:
144
- d = torch.diag(r, 0)
145
- q *= d.sign()
146
-
147
- return q.t()
148
-
149
- def exists(val):
150
- return val is not None
151
- def empty(tensor):
152
- return tensor.numel() == 0
153
- def default(val, d):
154
- return val if exists(val) else d
155
- def cast_tuple(val):
156
- return (val,) if not isinstance(val, tuple) else val
157
-
158
- class PCmer(nn.Module):
159
- def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
160
- super().__init__()
161
- self.num_layers = num_layers
162
- self.num_heads = num_heads
163
- self.dim_model = dim_model
164
- self.dim_values = dim_values
165
- self.dim_keys = dim_keys
166
- self.residual_dropout = residual_dropout
167
- self.attention_dropout = attention_dropout
168
-
169
- self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])
170
-
171
- def forward(self, phone, mask=None):
172
- for layer in self._layers:
173
- phone = layer(phone, mask)
174
-
175
- return phone
176
-
177
- class _EncoderLayer(nn.Module):
178
- def __init__(self, parent: PCmer):
179
- super().__init__()
180
- self.conformer = ConformerConvModule(parent.dim_model)
181
- self.norm = nn.LayerNorm(parent.dim_model)
182
- self.dropout = nn.Dropout(parent.residual_dropout)
183
- self.attn = SelfAttention(dim=parent.dim_model, heads=parent.num_heads, causal=False)
184
-
185
- def forward(self, phone, mask=None):
186
- phone = phone + (self.attn(self.norm(phone), mask=mask))
187
- phone = phone + (self.conformer(phone))
188
- return phone
189
-
190
- def calc_same_padding(kernel_size):
191
- pad = kernel_size // 2
192
- return (pad, pad - (kernel_size + 1) % 2)
193
-
194
- class Swish(nn.Module):
195
- def forward(self, x):
196
- return x * x.sigmoid()
197
-
198
- class Transpose(nn.Module):
199
- def __init__(self, dims):
200
- super().__init__()
201
- assert len(dims) == 2, "dims must specify exactly two dimensions to swap"
202
- self.dims = dims
203
-
204
- def forward(self, x):
205
- return x.transpose(*self.dims)
206
-
207
- class GLU(nn.Module):
208
- def __init__(self, dim):
209
- super().__init__()
210
- self.dim = dim
211
-
212
- def forward(self, x):
213
- out, gate = x.chunk(2, dim=self.dim)
214
- return out * gate.sigmoid()
215
-
216
- class DepthWiseConv1d(nn.Module):
217
- def __init__(self, chan_in, chan_out, kernel_size, padding):
218
- super().__init__()
219
- self.padding = padding
220
- self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)
221
-
222
- def forward(self, x):
223
- x = F.pad(x, self.padding)
224
- return self.conv(x)
225
-
226
- class ConformerConvModule(nn.Module):
227
- def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
228
- super().__init__()
229
-
230
- inner_dim = dim * expansion_factor
231
- padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)
232
-
233
- self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), GLU(dim=1), DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding), Swish(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))
234
-
235
- def forward(self, x):
236
- return self.net(x)
237
-
238
- def linear_attention(q, k, v):
239
- if v is None:
240
- out = torch.einsum("...ed,...nd->...ne", k, q)
241
- return out
242
- else:
243
- k_cumsum = k.sum(dim=-2)
244
- D_inv = 1.0 / (torch.einsum("...nd,...d->...n", q, k_cumsum.type_as(q)) + 1e-8)
245
- context = torch.einsum("...nd,...ne->...de", k, v)
246
- out = torch.einsum("...de,...nd,...n->...ne", context, q, D_inv)
247
- return out
248
-
249
- def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None):
250
- nb_full_blocks = int(nb_rows / nb_columns)
251
- block_list = []
252
-
253
- for _ in range(nb_full_blocks):
254
- q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)
255
- block_list.append(q)
256
-
257
- remaining_rows = nb_rows - nb_full_blocks * nb_columns
258
-
259
- if remaining_rows > 0:
260
- q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)
261
- block_list.append(q[:remaining_rows])
262
-
263
- final_matrix = torch.cat(block_list)
264
-
265
- if scaling == 0: multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
266
- elif scaling == 1: multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device=device)
267
- else: raise ValueError(f"Chia không hợp lệ {scaling}")
268
-
269
- return torch.diag(multiplier) @ final_matrix
270
-
271
- class FastAttention(nn.Module):
272
- def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False):
273
- super().__init__()
274
- nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))
275
-
276
- self.dim_heads = dim_heads
277
- self.nb_features = nb_features
278
- self.ortho_scaling = ortho_scaling
279
-
280
- self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q)
281
- projection_matrix = self.create_projection()
282
- self.register_buffer("projection_matrix", projection_matrix)
283
-
284
- self.generalized_attention = generalized_attention
285
- self.kernel_fn = kernel_fn
286
- self.no_projection = no_projection
287
- self.causal = causal
288
-
289
- @torch.no_grad()
290
- def redraw_projection_matrix(self):
291
- projections = self.create_projection()
292
- self.projection_matrix.copy_(projections)
293
- del projections
294
-
295
- def forward(self, q, k, v):
296
- device = q.device
297
-
298
- if self.no_projection:
299
- q = q.softmax(dim=-1)
300
- k = torch.exp(k) if self.causal else k.softmax(dim=-2)
301
- else:
302
- create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=device)
303
-
304
- q = create_kernel(q, is_query=True)
305
- k = create_kernel(k, is_query=False)
306
-
307
- attn_fn = linear_attention if not self.causal else self.causal_linear_fn
308
-
309
- if v is None:
310
- out = attn_fn(q, k, None)
311
- return out
312
- else:
313
- out = attn_fn(q, k, v)
314
- return out
315
-
316
- class SelfAttention(nn.Module):
317
- def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None, feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0.0, no_projection=False):
318
- super().__init__()
319
- assert dim % heads == 0
320
- dim_head = default(dim_head, dim // heads)
321
- inner_dim = dim_head * heads
322
- self.fast_attention = FastAttention(dim_head, nb_features, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q, no_projection=no_projection)
323
- self.heads = heads
324
- self.global_heads = heads - local_heads
325
- self.local_attn = (LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int(not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None)
326
- self.to_q = nn.Linear(dim, inner_dim)
327
- self.to_k = nn.Linear(dim, inner_dim)
328
- self.to_v = nn.Linear(dim, inner_dim)
329
- self.to_out = nn.Linear(inner_dim, dim)
330
- self.dropout = nn.Dropout(dropout)
331
-
332
- @torch.no_grad()
333
- def redraw_projection_matrix(self):
334
- self.fast_attention.redraw_projection_matrix()
335
-
336
- def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs):
337
- _, _, _, h, gh = *x.shape, self.heads, self.global_heads
338
-
339
- cross_attend = exists(context)
340
- context = default(context, x)
341
- context_mask = default(context_mask, mask) if not cross_attend else context_mask
342
- q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)
343
-
344
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
345
- (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))
346
-
347
- attn_outs = []
348
-
349
- if not empty(q):
350
- if exists(context_mask):
351
- global_mask = context_mask[:, None, :, None]
352
- v.masked_fill_(~global_mask, 0.0)
353
-
354
- if cross_attend: pass
355
- else: out = self.fast_attention(q, k, v)
356
-
357
- attn_outs.append(out)
358
-
359
- if not empty(lq):
360
- assert (not cross_attend), "local attention does not support cross-attention"
361
- out = self.local_attn(lq, lk, lv, input_mask=mask)
362
- attn_outs.append(out)
363
-
364
- out = torch.cat(attn_outs, dim=1)
365
- out = rearrange(out, "b h n d -> b n (h d)")
366
- out = self.to_out(out)
367
- return self.dropout(out)
368
-
369
- def l2_regularization(model, l2_alpha):
370
- l2_loss = []
371
- for module in model.modules():
372
- if type(module) is nn.Conv2d: l2_loss.append((module.weight**2).sum() / 2.0)
373
-
374
- return l2_alpha * sum(l2_loss)
375
-
376
- class _FCPE(nn.Module):
377
- def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, use_siren=False, use_full=False, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
378
- super().__init__()
379
- if use_siren: raise ValueError("Siren is not supported")
380
- if use_full: raise ValueError("The full model is not supported")
381
-
382
- self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10
383
- self.loss_l2_regularization = (loss_l2_regularization if (loss_l2_regularization is not None) else False)
384
- self.loss_l2_regularization_scale = (loss_l2_regularization_scale if (loss_l2_regularization_scale is not None) else 1)
385
- self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False
386
- self.loss_grad1_mse_scale = (loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1)
387
- self.f0_max = f0_max if (f0_max is not None) else 1975.5
388
- self.f0_min = f0_min if (f0_min is not None) else 32.70
389
- self.confidence = confidence if (confidence is not None) else False
390
- self.threshold = threshold if (threshold is not None) else 0.05
391
- self.use_input_conv = use_input_conv if (use_input_conv is not None) else True
392
- self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
393
- self.register_buffer("cent_table", self.cent_table_b)
394
-
395
- _leaky = nn.LeakyReLU()
396
-
397
- self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), _leaky, nn.Conv1d(n_chans, n_chans, 3, 1, 1))
398
- self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
399
- self.norm = nn.LayerNorm(n_chans)
400
- self.n_out = out_dims
401
- self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))
402
-
403
- def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax"):
404
- if cdecoder == "argmax": self.cdecoder = self.cents_decoder
405
- elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder
406
-
407
- x = (self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)
408
- x = self.decoder(x)
409
- x = self.norm(x)
410
- x = self.dense_out(x)
411
- x = torch.sigmoid(x)
412
-
413
- if not infer:
414
- gt_cent_f0 = self.f0_to_cent(gt_f0)
415
- gt_cent_f0 = self.gaussian_blurred_cent(gt_cent_f0)
416
- loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, gt_cent_f0)
417
-
418
- if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)
419
-
420
- x = loss_all
421
-
422
- if infer:
423
- x = self.cdecoder(x)
424
- x = self.cent_to_f0(x)
425
- x = (1 + x / 700).log() if not return_hz_f0 else x
426
-
427
- return x
428
-
429
- def cents_decoder(self, y, mask=True):
430
- B, N, _ = y.size()
431
- ci = self.cent_table[None, None, :].expand(B, N, -1)
432
- rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
433
- if mask:
434
- confident = torch.max(y, dim=-1, keepdim=True)[0]
435
- confident_mask = torch.ones_like(confident)
436
- confident_mask[confident <= self.threshold] = float("-INF")
437
- rtn = rtn * confident_mask
438
-
439
- return (rtn, confident) if self.confidence else rtn
440
-
441
- def cents_local_decoder(self, y, mask=True):
442
- B, N, _ = y.size()
443
- ci = self.cent_table[None, None, :].expand(B, N, -1)
444
- confident, max_index = torch.max(y, dim=-1, keepdim=True)
445
- local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
446
- local_argmax_index = torch.clamp(local_argmax_index, 0, self.n_out - 1)
447
- ci_l = torch.gather(ci, -1, local_argmax_index)
448
- y_l = torch.gather(y, -1, local_argmax_index)
449
- rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)
450
-
451
- if mask:
452
- confident_mask = torch.ones_like(confident)
453
- confident_mask[confident <= self.threshold] = float("-INF")
454
- rtn = rtn * confident_mask
455
-
456
- return (rtn, confident) if self.confidence else rtn
457
-
458
- def cent_to_f0(self, cent):
459
- return 10.0 * 2 ** (cent / 1200.0)
460
-
461
- def f0_to_cent(self, f0):
462
- return 1200.0 * torch.log2(f0 / 10.0)
463
-
464
- def gaussian_blurred_cent(self, cents):
465
- mask = (cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0)))
466
- B, N, _ = cents.size()
467
- ci = self.cent_table[None, None, :].expand(B, N, -1)
468
- return torch.exp(-torch.square(ci - cents) / 1250) * mask.float()
469
-
470
- class FCPEInfer:
471
- def __init__(self, model_path, device=None, dtype=torch.float32):
472
- if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
473
-
474
- self.device = device
475
- ckpt = torch.load(model_path, map_location=torch.device(self.device))
476
- self.args = DotDict(ckpt["config"])
477
- self.dtype = dtype
478
- model = _FCPE(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, use_siren=self.args.model.use_siren, use_full=self.args.model.use_full, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.args.model.f0_max, f0_min=self.args.model.f0_min, confidence=self.args.model.confidence)
479
- model.to(self.device).to(self.dtype)
480
- model.load_state_dict(ckpt["model"])
481
- model.eval()
482
- self.model = model
483
- self.wav2mel = Wav2Mel(self.args, dtype=self.dtype, device=self.device)
484
-
485
- @torch.no_grad()
486
- def __call__(self, audio, sr, threshold=0.05):
487
- self.model.threshold = threshold
488
- audio = audio[None, :]
489
- mel = self.wav2mel(audio=audio, sample_rate=sr).to(self.dtype)
490
- f0 = self.model(mel=mel, infer=True, return_hz_f0=True)
491
- return f0
492
-
493
- class Wav2Mel:
494
- def __init__(self, args, device=None, dtype=torch.float32):
495
- self.sample_rate = args.mel.sampling_rate
496
- self.hop_size = args.mel.hop_size
497
-
498
- if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"
499
-
500
- self.device = device
501
- self.dtype = dtype
502
- self.stft = STFT(args.mel.sampling_rate, args.mel.num_mels, args.mel.n_fft, args.mel.win_size, args.mel.hop_size, args.mel.fmin, args.mel.fmax)
503
- self.resample_kernel = {}
504
-
505
- def extract_nvstft(self, audio, keyshift=0, train=False):
506
- mel = self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)
507
- return mel
508
-
509
- def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
510
- audio = audio.to(self.dtype).to(self.device)
511
-
512
- if sample_rate == self.sample_rate: audio_res = audio
513
- else:
514
- key_str = str(sample_rate)
515
-
516
- if key_str not in self.resample_kernel: self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)
517
-
518
- self.resample_kernel[key_str] = (self.resample_kernel[key_str].to(self.dtype).to(self.device))
519
- audio_res = self.resample_kernel[key_str](audio)
520
-
521
- mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
522
- n_frames = int(audio.shape[1] // self.hop_size) + 1
523
- mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)
524
- mel = mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel
525
- return mel
526
-
527
- def __call__(self, audio, sample_rate, keyshift=0, train=False):
528
- return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)
529
-
530
- class DotDict(dict):
531
- def __getattr__(*args):
532
- val = dict.get(*args)
533
- return DotDict(val) if type(val) is dict else val
534
-
535
- __setattr__ = dict.__setitem__
536
- __delattr__ = dict.__delitem__
537
-
538
- class F0Predictor(object):
539
- def compute_f0(self, wav, p_len): pass
540
- def compute_f0_uv(self, wav, p_len): pass
541
-
542
- class FCPE(F0Predictor):
543
- def __init__(self, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=44100, threshold=0.05):
544
- self.fcpe = FCPEInfer(model_path, device=device, dtype=dtype)
545
- self.hop_length = hop_length
546
- self.f0_min = f0_min
547
- self.f0_max = f0_max
548
- self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
549
- self.threshold = threshold
550
- self.sample_rate = sample_rate
551
- self.dtype = dtype
552
- self.name = "fcpe"
553
-
554
- def repeat_expand(self, content: Union[torch.Tensor, np.ndarray], target_len, mode = "nearest"):
555
- ndim = content.ndim
556
- content = (content[None, None] if ndim == 1 else content[None] if ndim == 2 else content)
557
- assert content.ndim == 3
558
- is_np = isinstance(content, np.ndarray)
559
- content = torch.from_numpy(content) if is_np else content
560
- results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
561
- results = results.numpy() if is_np else results
562
- return results[0, 0] if ndim == 1 else results[0] if ndim == 2 else results
563
-
564
- def post_process(self, x, sample_rate, f0, pad_to):
565
- f0 = (torch.from_numpy(f0).float().to(x.device) if isinstance(f0, np.ndarray) else f0)
566
- f0 = self.repeat_expand(f0, pad_to) if pad_to is not None else f0
567
-
568
- vuv_vector = torch.zeros_like(f0)
569
- vuv_vector[f0 > 0.0] = 1.0
570
- vuv_vector[f0 <= 0.0] = 0.0
571
-
572
- nzindex = torch.nonzero(f0).squeeze()
573
- f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
574
- time_org = self.hop_length / sample_rate * nzindex.cpu().numpy()
575
- time_frame = np.arange(pad_to) * self.hop_length / sample_rate
576
-
577
- vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0]
578
-
579
- if f0.shape[0] <= 0: return np.zeros(pad_to), vuv_vector.cpu().numpy()
580
- if f0.shape[0] == 1: return np.ones(pad_to) * f0[0], vuv_vector.cpu().numpy()
581
-
582
- f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
583
- return f0, vuv_vector.cpu().numpy()
584
-
585
- def compute_f0(self, wav, p_len=None):
586
- x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
587
- p_len = x.shape[0] // self.hop_length if p_len is None else p_len
588
- f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold)[0, :, 0]
589
-
590
- if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))
591
-
592
- return self.post_process(x, self.sample_rate, f0, p_len)[0]
593
- def compute_f0_uv(self, wav, p_len=None):
594
- x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
595
- p_len = x.shape[0] // self.hop_length if p_len is None else p_len
596
- f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold)[0, :, 0]
597
-
598
- if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))
599
-
600
- return self.post_process(x, self.sample_rate, f0, p_len)
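Note: a minimal usage sketch of the FCPE predictor removed above, assuming the module were still importable from main/library/predictors/FCPE.py and that a compatible checkpoint existed at the hypothetical path shown (neither is part of this diff).

    import numpy as np
    import torch
    from main.library.predictors.FCPE import FCPE  # assumed import path of the deleted module

    # A real caller would load a mono float32 waveform from disk; silence is used here only as a stand-in.
    wav = np.zeros(16000, dtype=np.float32)  # 1 second at 16 kHz

    predictor = FCPE("assets/models/fcpe.pt",  # hypothetical checkpoint path
                     hop_length=160, sample_rate=16000, device="cpu",
                     dtype=torch.float32, threshold=0.03)
    f0 = predictor.compute_f0(wav, p_len=wav.shape[0] // 160)  # per-frame F0 in Hz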
main/library/predictors/RMVPE.py DELETED
@@ -1,270 +0,0 @@
1
- import torch
2
- import numpy as np
3
- import torch.nn as nn
4
- import torch.nn.functional as F
5
-
6
- from typing import List
7
- from librosa.filters import mel
8
-
9
- N_MELS = 128
10
- N_CLASS = 360
11
-
12
-
13
- class ConvBlockRes(nn.Module):
14
- def __init__(self, in_channels, out_channels, momentum=0.01):
15
- super(ConvBlockRes, self).__init__()
16
- self.conv = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
17
-
18
- if in_channels != out_channels:
19
- self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
20
- self.is_shortcut = True
21
- else: self.is_shortcut = False
22
-
23
- def forward(self, x):
24
- if self.is_shortcut: return self.conv(x) + self.shortcut(x)
25
- else: return self.conv(x) + x
26
-
27
- class ResEncoderBlock(nn.Module):
28
- def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
29
- super(ResEncoderBlock, self).__init__()
30
- self.n_blocks = n_blocks
31
- self.conv = nn.ModuleList()
32
- self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
33
-
34
- for _ in range(n_blocks - 1):
35
- self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
36
-
37
- self.kernel_size = kernel_size
38
-
39
- if self.kernel_size is not None:
40
- self.pool = nn.AvgPool2d(kernel_size=kernel_size)
41
-
42
- def forward(self, x):
43
- for i in range(self.n_blocks):
44
- x = self.conv[i](x)
45
-
46
- if self.kernel_size is not None: return x, self.pool(x)
47
- else: return x
48
-
49
- class Encoder(nn.Module):
50
- def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
51
- super(Encoder, self).__init__()
52
- self.n_encoders = n_encoders
53
- self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
54
- self.layers = nn.ModuleList()
55
- self.latent_channels = []
56
-
57
- for i in range(self.n_encoders):
58
- self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
59
- self.latent_channels.append([out_channels, in_size])
60
- in_channels = out_channels
61
- out_channels *= 2
62
- in_size //= 2
63
-
64
- self.out_size = in_size
65
- self.out_channel = out_channels
66
-
67
- def forward(self, x: torch.Tensor):
68
- concat_tensors: List[torch.Tensor] = []
69
- x = self.bn(x)
70
-
71
- for i in range(self.n_encoders):
72
- t, x = self.layers[i](x)
73
- concat_tensors.append(t)
74
-
75
- return x, concat_tensors
76
-
77
- class Intermediate(nn.Module):
78
- def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
79
- super(Intermediate, self).__init__()
80
- self.n_inters = n_inters
81
- self.layers = nn.ModuleList()
82
- self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
83
-
84
- for _ in range(self.n_inters - 1):
85
- self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
86
-
87
- def forward(self, x):
88
- for i in range(self.n_inters):
89
- x = self.layers[i](x)
90
- return x
91
-
92
- class ResDecoderBlock(nn.Module):
93
- def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
94
- super(ResDecoderBlock, self).__init__()
95
- out_padding = (0, 1) if stride == (1, 2) else (1, 1)
96
- self.n_blocks = n_blocks
97
- self.conv1 = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), output_padding=out_padding, bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
98
- self.conv2 = nn.ModuleList()
99
- self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
100
- for _ in range(n_blocks - 1):
101
- self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
102
-
103
- def forward(self, x, concat_tensor):
104
- x = self.conv1(x)
105
- x = torch.cat((x, concat_tensor), dim=1)
106
- for i in range(self.n_blocks):
107
- x = self.conv2[i](x)
108
-
109
- return x
110
-
111
- class Decoder(nn.Module):
112
- def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
113
- super(Decoder, self).__init__()
114
- self.layers = nn.ModuleList()
115
- self.n_decoders = n_decoders
116
-
117
- for _ in range(self.n_decoders):
118
- out_channels = in_channels // 2
119
- self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
120
- in_channels = out_channels
121
-
122
- def forward(self, x, concat_tensors):
123
- for i in range(self.n_decoders):
124
- x = self.layers[i](x, concat_tensors[-1 - i])
125
-
126
- return x
127
-
128
- class DeepUnet(nn.Module):
129
- def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
130
- super(DeepUnet, self).__init__()
131
- self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
132
- self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
133
- self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
134
-
135
- def forward(self, x):
136
- x, concat_tensors = self.encoder(x)
137
- x = self.intermediate(x)
138
- x = self.decoder(x, concat_tensors)
139
- return x
140
-
141
- class E2E(nn.Module):
142
- def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
143
- super(E2E, self).__init__()
144
- self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
145
- self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
146
-
147
- if n_gru: self.fc = nn.Sequential(BiGRU(3 * 128, 256, n_gru), nn.Linear(512, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
148
- else: self.fc = nn.Sequential(nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
149
-
150
- def forward(self, mel):
151
- mel = mel.transpose(-1, -2).unsqueeze(1)
152
- x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
153
- x = self.fc(x)
154
- return x
155
-
156
- class MelSpectrogram(torch.nn.Module):
157
- def __init__(self, is_half, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
158
- super().__init__()
159
- n_fft = win_length if n_fft is None else n_fft
160
- self.hann_window = {}
161
- mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
162
- mel_basis = torch.from_numpy(mel_basis).float()
163
- self.register_buffer("mel_basis", mel_basis)
164
- self.n_fft = win_length if n_fft is None else n_fft
165
- self.hop_length = hop_length
166
- self.win_length = win_length
167
- self.sample_rate = sample_rate
168
- self.n_mel_channels = n_mel_channels
169
- self.clamp = clamp
170
- self.is_half = is_half
171
-
172
- def forward(self, audio, keyshift=0, speed=1, center=True):
173
- factor = 2 ** (keyshift / 12)
174
- n_fft_new = int(np.round(self.n_fft * factor))
175
- win_length_new = int(np.round(self.win_length * factor))
176
- hop_length_new = int(np.round(self.hop_length * speed))
177
- keyshift_key = str(keyshift) + "_" + str(audio.device)
178
-
179
- if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
180
-
181
- fft = torch.stft(audio, n_fft=n_fft_new, hop_length=hop_length_new, win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
182
- magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
183
-
184
- if keyshift != 0:
185
- size = self.n_fft // 2 + 1
186
- resize = magnitude.size(1)
187
-
188
- if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
189
-
190
- magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
191
-
192
- mel_output = torch.matmul(self.mel_basis, magnitude)
193
-
194
- if self.is_half: mel_output = mel_output.half()
195
-
196
- log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
197
- return log_mel_spec
198
-
199
- class RMVPE:
200
- def __init__(self, model_path, is_half, device=None):
201
- self.resample_kernel = {}
202
- model = E2E(4, 1, (2, 2))
203
- ckpt = torch.load(model_path, map_location="cpu")
204
- model.load_state_dict(ckpt)
205
- model.eval()
206
-
207
- if is_half: model = model.half()
208
-
209
- self.model = model
210
- self.resample_kernel = {}
211
- self.is_half = is_half
212
- self.device = device
213
- self.mel_extractor = MelSpectrogram(is_half, N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
214
- self.model = self.model.to(device)
215
- cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
216
- self.cents_mapping = np.pad(cents_mapping, (4, 4))
217
-
218
- def mel2hidden(self, mel):
219
- with torch.no_grad():
220
- n_frames = mel.shape[-1]
221
- mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")
222
- hidden = self.model(mel)
223
- return hidden[:, :n_frames]
224
-
225
- def decode(self, hidden, thred=0.03):
226
- cents_pred = self.to_local_average_cents(hidden, thred=thred)
227
- f0 = 10 * (2 ** (cents_pred / 1200))
228
- f0[f0 == 10] = 0
229
- return f0
230
-
231
- def infer_from_audio(self, audio, thred=0.03):
232
- audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
233
- mel = self.mel_extractor(audio, center=True)
234
- hidden = self.mel2hidden(mel)
235
- hidden = hidden.squeeze(0).cpu().numpy()
236
-
237
- if self.is_half: hidden = hidden.astype("float32")
238
-
239
- f0 = self.decode(hidden, thred=thred)
240
- return f0
241
-
242
- def to_local_average_cents(self, salience, thred=0.05):
243
- center = np.argmax(salience, axis=1)
244
- salience = np.pad(salience, ((0, 0), (4, 4)))
245
- center += 4
246
- todo_salience = []
247
- todo_cents_mapping = []
248
- starts = center - 4
249
- ends = center + 5
250
-
251
- for idx in range(salience.shape[0]):
252
- todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
253
- todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
254
-
255
- todo_salience = np.array(todo_salience)
256
- todo_cents_mapping = np.array(todo_cents_mapping)
257
- product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
258
- weight_sum = np.sum(todo_salience, 1)
259
- devided = product_sum / weight_sum
260
- maxx = np.max(salience, axis=1)
261
- devided[maxx <= thred] = 0
262
- return devided
263
-
264
- class BiGRU(nn.Module):
265
- def __init__(self, input_features, hidden_features, num_layers):
266
- super(BiGRU, self).__init__()
267
- self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
268
-
269
- def forward(self, x):
270
- return self.gru(x)[0]
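Likewise, a hedged sketch of driving the RMVPE wrapper removed above; the checkpoint path and import path are assumptions, not part of the diff.

    import numpy as np
    import torch
    from main.library.predictors.RMVPE import RMVPE  # assumed import path of the deleted module

    device = "cuda" if torch.cuda.is_available() else "cpu"
    rmvpe = RMVPE("assets/models/rmvpe.pt", is_half=False, device=device)  # hypothetical checkpoint

    audio_16k = np.zeros(16000, dtype=np.float32)  # mono 16 kHz waveform (loading omitted)
    f0 = rmvpe.infer_from_audio(audio_16k, thred=0.03)  # per-frame F0 in Hz, 0 where unvoiced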
main/library/uvr5_separator/common_separator.py DELETED
@@ -1,270 +0,0 @@
1
- import os
2
- import gc
3
- import sys
4
- import torch
5
- import librosa
6
-
7
- import numpy as np
8
- import soundfile as sf
9
-
10
- from logging import Logger
11
- from pydub import AudioSegment
12
-
13
- now_dir = os.getcwd()
14
- sys.path.append(now_dir)
15
-
16
- from . import spec_utils
17
- from main.configs.config import Config
18
-
19
- translations = Config().translations
20
-
21
- class CommonSeparator:
22
- ALL_STEMS = "All Stems"
23
- VOCAL_STEM = "Vocals"
24
- INST_STEM = "Instrumental"
25
- OTHER_STEM = "Other"
26
- BASS_STEM = "Bass"
27
- DRUM_STEM = "Drums"
28
- GUITAR_STEM = "Guitar"
29
- PIANO_STEM = "Piano"
30
- SYNTH_STEM = "Synthesizer"
31
- STRINGS_STEM = "Strings"
32
- WOODWINDS_STEM = "Woodwinds"
33
- BRASS_STEM = "Brass"
34
- WIND_INST_STEM = "Wind Inst"
35
- NO_OTHER_STEM = "No Other"
36
- NO_BASS_STEM = "No Bass"
37
- NO_DRUM_STEM = "No Drums"
38
- NO_GUITAR_STEM = "No Guitar"
39
- NO_PIANO_STEM = "No Piano"
40
- NO_SYNTH_STEM = "No Synthesizer"
41
- NO_STRINGS_STEM = "No Strings"
42
- NO_WOODWINDS_STEM = "No Woodwinds"
43
- NO_WIND_INST_STEM = "No Wind Inst"
44
- NO_BRASS_STEM = "No Brass"
45
- PRIMARY_STEM = "Primary Stem"
46
- SECONDARY_STEM = "Secondary Stem"
47
- LEAD_VOCAL_STEM = "lead_only"
48
- BV_VOCAL_STEM = "backing_only"
49
- LEAD_VOCAL_STEM_I = "with_lead_vocals"
50
- BV_VOCAL_STEM_I = "with_backing_vocals"
51
- LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
52
- BV_VOCAL_STEM_LABEL = "Backing Vocals"
53
- NO_STEM = "No "
54
-
55
- STEM_PAIR_MAPPER = {VOCAL_STEM: INST_STEM, INST_STEM: VOCAL_STEM, LEAD_VOCAL_STEM: BV_VOCAL_STEM, BV_VOCAL_STEM: LEAD_VOCAL_STEM, PRIMARY_STEM: SECONDARY_STEM}
56
-
57
- NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM)
58
-
59
-
60
- def __init__(self, config):
61
- self.logger: Logger = config.get("logger")
62
- self.log_level: int = config.get("log_level")
63
- self.torch_device = config.get("torch_device")
64
- self.torch_device_cpu = config.get("torch_device_cpu")
65
- self.torch_device_mps = config.get("torch_device_mps")
66
- self.onnx_execution_provider = config.get("onnx_execution_provider")
67
- self.model_name = config.get("model_name")
68
- self.model_path = config.get("model_path")
69
- self.model_data = config.get("model_data")
70
- self.output_dir = config.get("output_dir")
71
- self.output_format = config.get("output_format")
72
- self.output_bitrate = config.get("output_bitrate")
73
- self.normalization_threshold = config.get("normalization_threshold")
74
- self.enable_denoise = config.get("enable_denoise")
75
- self.output_single_stem = config.get("output_single_stem")
76
- self.invert_using_spec = config.get("invert_using_spec")
77
- self.sample_rate = config.get("sample_rate")
78
-
79
- self.primary_stem_name = None
80
- self.secondary_stem_name = None
81
-
82
- if "training" in self.model_data and "instruments" in self.model_data["training"]:
83
- instruments = self.model_data["training"]["instruments"]
84
-
85
- if instruments:
86
- self.primary_stem_name = instruments[0]
87
- self.secondary_stem_name = instruments[1] if len(instruments) > 1 else self.secondary_stem(self.primary_stem_name)
88
-
89
- if self.primary_stem_name is None:
90
- self.primary_stem_name = self.model_data.get("primary_stem", "Vocals")
91
- self.secondary_stem_name = self.secondary_stem(self.primary_stem_name)
92
-
93
- self.is_karaoke = self.model_data.get("is_karaoke", False)
94
- self.is_bv_model = self.model_data.get("is_bv_model", False)
95
- self.bv_model_rebalance = self.model_data.get("is_bv_model_rebalanced", 0)
96
-
97
- self.logger.debug(translations["info"].format(model_name=self.model_name, model_path=self.model_path))
98
- self.logger.debug(translations["info_2"].format(output_dir=self.output_dir, output_format=self.output_format))
99
- self.logger.debug(translations["info_3"].format(normalization_threshold=self.normalization_threshold))
100
- self.logger.debug(translations["info_4"].format(enable_denoise=self.enable_denoise, output_single_stem=self.output_single_stem))
101
- self.logger.debug(translations["info_5"].format(invert_using_spec=self.invert_using_spec, sample_rate=self.sample_rate))
102
- self.logger.debug(translations["info_6"].format(primary_stem_name=self.primary_stem_name, secondary_stem_name=self.secondary_stem_name))
103
- self.logger.debug(translations["info_7"].format(is_karaoke=self.is_karaoke, is_bv_model=self.is_bv_model, bv_model_rebalance=self.bv_model_rebalance))
104
-
105
- self.audio_file_path = None
106
- self.audio_file_base = None
107
- self.primary_source = None
108
- self.secondary_source = None
109
- self.primary_stem_output_path = None
110
- self.secondary_stem_output_path = None
111
- self.cached_sources_map = {}
112
-
113
- def secondary_stem(self, primary_stem: str):
114
- primary_stem = primary_stem if primary_stem else self.NO_STEM
115
-
116
- return self.STEM_PAIR_MAPPER[primary_stem] if primary_stem in self.STEM_PAIR_MAPPER else primary_stem.replace(self.NO_STEM, "") if self.NO_STEM in primary_stem else f"{self.NO_STEM}{primary_stem}"
117
-
118
- def separate(self, audio_file_path):
119
- pass
120
-
121
- def final_process(self, stem_path, source, stem_name):
122
- self.logger.debug(translations["success_process"].format(stem_name=stem_name))
123
- self.write_audio(stem_path, source)
124
-
125
- return {stem_name: source}
126
-
127
- def cached_sources_clear(self):
128
- self.cached_sources_map = {}
129
-
130
- def cached_source_callback(self, model_architecture, model_name=None):
131
- model, sources = None, None
132
- mapper = self.cached_sources_map[model_architecture]
133
-
134
- for key, value in mapper.items():
135
- if model_name in key:
136
- model = key
137
- sources = value
138
-
139
- return model, sources
140
-
141
- def cached_model_source_holder(self, model_architecture, sources, model_name=None):
142
- self.cached_sources_map[model_architecture] = {**self.cached_sources_map.get(model_architecture, {}), **{model_name: sources}}
143
-
144
- def prepare_mix(self, mix):
145
- audio_path = mix
146
-
147
- if not isinstance(mix, np.ndarray):
148
- self.logger.debug(f"{translations['load_audio']}: {mix}")
149
- mix, sr = librosa.load(mix, mono=False, sr=self.sample_rate)
150
- self.logger.debug(translations["load_audio_success"].format(sr=sr, shape=mix.shape))
151
- else:
152
- self.logger.debug(translations["convert_mix"])
153
- mix = mix.T
154
- self.logger.debug(translations["convert_shape"].format(shape=mix.shape))
155
-
156
- if isinstance(audio_path, str):
157
- if not np.any(mix):
158
- error_msg = translations["audio_not_valid"].format(audio_path=audio_path)
159
- self.logger.error(error_msg)
160
- raise ValueError(error_msg)
161
- else: self.logger.debug(translations["audio_valid"])
162
-
163
- if mix.ndim == 1:
164
- self.logger.debug(translations["mix_single"])
165
- mix = np.asfortranarray([mix, mix])
166
- self.logger.debug(translations["convert_mix_audio"])
167
-
168
- self.logger.debug(translations["mix_success_2"])
169
- return mix
170
-
171
- def write_audio(self, stem_path: str, stem_source):
172
- duration_seconds = librosa.get_duration(filename=self.audio_file_path)
173
- duration_hours = duration_seconds / 3600
174
- self.logger.info(translations["duration"].format(duration_hours=f"{duration_hours:.2f}", duration_seconds=f"{duration_seconds:.2f}"))
175
-
176
- if duration_hours >= 1:
177
- self.logger.warning(translations["write"].format(name="soundfile"))
178
- self.write_audio_soundfile(stem_path, stem_source)
179
- else:
180
- self.logger.info(translations["write"].format(name="pydub"))
181
- self.write_audio_pydub(stem_path, stem_source)
182
-
183
- def write_audio_pydub(self, stem_path: str, stem_source):
184
- self.logger.debug(f"{translations['write_audio'].format(name='write_audio_pydub')} {stem_path}")
185
-
186
- stem_source = spec_utils.normalize(wave=stem_source, max_peak=self.normalization_threshold)
187
-
188
- if np.max(np.abs(stem_source)) < 1e-6:
189
- self.logger.warning(translations["original_not_valid"])
190
- return
191
-
192
- if self.output_dir:
193
- os.makedirs(self.output_dir, exist_ok=True)
194
- stem_path = os.path.join(self.output_dir, stem_path)
195
-
196
- self.logger.debug(f"{translations['shape_audio']}: {stem_source.shape}")
197
- self.logger.debug(f"{translations['convert_data']}: {stem_source.dtype}")
198
-
199
- if stem_source.dtype != np.int16:
200
- stem_source = (stem_source * 32767).astype(np.int16)
201
- self.logger.debug(translations["original_source_to_int16"])
202
-
203
- stem_source_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
204
- stem_source_interleaved[0::2] = stem_source[:, 0]
205
- stem_source_interleaved[1::2] = stem_source[:, 1]
206
-
207
- self.logger.debug(f"{translations['shape_audio_2']}: {stem_source_interleaved.shape}")
208
-
209
- try:
210
- audio_segment = AudioSegment(stem_source_interleaved.tobytes(), frame_rate=self.sample_rate, sample_width=stem_source.dtype.itemsize, channels=2)
211
- self.logger.debug(translations["create_audiosegment"])
212
- except (IOError, ValueError) as e:
213
- self.logger.error(f"{translations['create_audiosegment_error']}: {e}")
214
- return
215
-
216
- file_format = stem_path.lower().split(".")[-1]
217
-
218
- if file_format == "m4a": file_format = "mp4"
219
- elif file_format == "mka": file_format = "matroska"
220
-
221
- bitrate = "320k" if file_format == "mp3" and self.output_bitrate is None else self.output_bitrate
222
-
223
- try:
224
- audio_segment.export(stem_path, format=file_format, bitrate=bitrate)
225
- self.logger.debug(f"{translations['export_success']} {stem_path}")
226
- except (IOError, ValueError) as e:
227
- self.logger.error(f"{translations['export_error']}: {e}")
228
-
229
- def write_audio_soundfile(self, stem_path: str, stem_source):
230
- self.logger.debug(f"{translations['write_audio'].format(name='write_audio_soundfile')}: {stem_path}")
231
-
232
- if stem_source.shape[1] == 2:
233
- if stem_source.flags["F_CONTIGUOUS"]: stem_source = np.ascontiguousarray(stem_source)
234
- else:
235
- stereo_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
236
- stereo_interleaved[0::2] = stem_source[:, 0]
237
-
238
- stereo_interleaved[1::2] = stem_source[:, 1]
239
- stem_source = stereo_interleaved
240
-
241
- self.logger.debug(f"{translations['shape_audio_2']}: {stem_source.shape}")
242
-
243
- try:
244
- sf.write(stem_path, stem_source, self.sample_rate)
245
- self.logger.debug(f"{translations['export_success']} {stem_path}")
246
- except Exception as e:
247
- self.logger.error(f"{translations['export_error']}: {e}")
248
-
249
- def clear_gpu_cache(self):
250
- self.logger.debug(translations["clean"])
251
- gc.collect()
252
-
253
- if self.torch_device == torch.device("mps"):
254
- self.logger.debug(translations["clean_cache"].format(name="MPS"))
255
- torch.mps.empty_cache()
256
-
257
- if self.torch_device == torch.device("cuda"):
258
- self.logger.debug(translations["clean_cache"].format(name="CUDA"))
259
- torch.cuda.empty_cache()
260
-
261
- def clear_file_specific_paths(self):
262
- self.logger.info(translations["del_path"])
263
- self.audio_file_path = None
264
- self.audio_file_base = None
265
-
266
- self.primary_source = None
267
- self.secondary_source = None
268
-
269
- self.primary_stem_output_path = None
270
- self.secondary_stem_output_path = None
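The stem-pairing rule implemented by CommonSeparator.secondary_stem above can be shown in isolation; the snippet below re-implements it for illustration only and is not an import from the deleted module.

    STEM_PAIR_MAPPER = {"Vocals": "Instrumental", "Instrumental": "Vocals",
                        "lead_only": "backing_only", "backing_only": "lead_only",
                        "Primary Stem": "Secondary Stem"}
    NO_STEM = "No "

    def secondary_stem(primary_stem: str) -> str:
        primary_stem = primary_stem if primary_stem else NO_STEM
        if primary_stem in STEM_PAIR_MAPPER:
            return STEM_PAIR_MAPPER[primary_stem]
        # "No X" pairs back with "X"; any other stem "X" pairs with "No X".
        return primary_stem.replace(NO_STEM, "") if NO_STEM in primary_stem else f"{NO_STEM}{primary_stem}"

    print(secondary_stem("Vocals"))    # Instrumental
    print(secondary_stem("Drums"))     # No Drums
    print(secondary_stem("No Piano"))  # Piano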
main/library/uvr5_separator/demucs/apply.py DELETED
@@ -1,280 +0,0 @@
1
- import tqdm
2
- import torch
3
- import random
4
-
5
- import typing as tp
6
-
7
- from torch import nn
8
- from torch.nn import functional as F
9
- from concurrent.futures import ThreadPoolExecutor
10
-
11
- from .demucs import Demucs
12
- from .hdemucs import HDemucs
13
- from .htdemucs import HTDemucs
14
- from .utils import center_trim
15
-
16
-
17
- Model = tp.Union[Demucs, HDemucs, HTDemucs]
18
-
19
- class DummyPoolExecutor:
20
- class DummyResult:
21
- def __init__(self, func, *args, **kwargs):
22
- self.func = func
23
- self.args = args
24
- self.kwargs = kwargs
25
-
26
- def result(self):
27
- return self.func(*self.args, **self.kwargs)
28
-
29
- def __init__(self, workers=0):
30
- pass
31
-
32
- def submit(self, func, *args, **kwargs):
33
- return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
34
-
35
- def __enter__(self):
36
- return self
37
-
38
- def __exit__(self, exc_type, exc_value, exc_tb):
39
- return
40
-
41
- class BagOfModels(nn.Module):
42
- def __init__(self, models: tp.List[Model], weights: tp.Optional[tp.List[tp.List[float]]] = None, segment: tp.Optional[float] = None):
43
- super().__init__()
44
- assert len(models) > 0
45
- first = models[0]
46
-
47
- for other in models:
48
- assert other.sources == first.sources
49
- assert other.samplerate == first.samplerate
50
- assert other.audio_channels == first.audio_channels
51
-
52
- if segment is not None: other.segment = segment
53
-
54
- self.audio_channels = first.audio_channels
55
- self.samplerate = first.samplerate
56
- self.sources = first.sources
57
- self.models = nn.ModuleList(models)
58
-
59
- if weights is None: weights = [[1.0 for _ in first.sources] for _ in models]
60
- else:
61
- assert len(weights) == len(models)
62
-
63
- for weight in weights:
64
- assert len(weight) == len(first.sources)
65
-
66
- self.weights = weights
67
-
68
- def forward(self, x):
69
- raise NotImplementedError("`apply_model`")
70
-
71
-
72
- class TensorChunk:
73
- def __init__(self, tensor, offset=0, length=None):
74
- total_length = tensor.shape[-1]
75
- assert offset >= 0
76
- assert offset < total_length
77
-
78
- length = total_length - offset if length is None else min(total_length - offset, length)
79
-
80
- if isinstance(tensor, TensorChunk):
81
- self.tensor = tensor.tensor
82
- self.offset = offset + tensor.offset
83
- else:
84
- self.tensor = tensor
85
- self.offset = offset
86
-
87
- self.length = length
88
- self.device = tensor.device
89
-
90
- @property
91
- def shape(self):
92
- shape = list(self.tensor.shape)
93
- shape[-1] = self.length
94
-
95
- return shape
96
-
97
- def padded(self, target_length):
98
- delta = target_length - self.length
99
- total_length = self.tensor.shape[-1]
100
- assert delta >= 0
101
-
102
- start = self.offset - delta // 2
103
- end = start + target_length
104
-
105
- correct_start = max(0, start)
106
- correct_end = min(total_length, end)
107
-
108
- pad_left = correct_start - start
109
- pad_right = end - correct_end
110
-
111
- out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
112
-
113
- assert out.shape[-1] == target_length
114
-
115
- return out
116
-
117
-
118
- def tensor_chunk(tensor_or_chunk):
119
- if isinstance(tensor_or_chunk, TensorChunk): return tensor_or_chunk
120
- else:
121
- assert isinstance(tensor_or_chunk, torch.Tensor)
122
-
123
- return TensorChunk(tensor_or_chunk)
124
-
125
-
126
- def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1.0, static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
127
- global fut_length
128
- global bag_num
129
- global prog_bar
130
-
131
- device = mix.device if device is None else torch.device(device)
132
-
133
- if pool is None: pool = ThreadPoolExecutor(num_workers) if num_workers > 0 and device.type == "cpu" else DummyPoolExecutor()
134
-
135
- kwargs = {
136
- "shifts": shifts,
137
- "split": split,
138
- "overlap": overlap,
139
- "transition_power": transition_power,
140
- "progress": progress,
141
- "device": device,
142
- "pool": pool,
143
- "set_progress_bar": set_progress_bar,
144
- "static_shifts": static_shifts,
145
- }
146
-
147
- if isinstance(model, BagOfModels):
148
- estimates = 0
149
- totals = [0] * len(model.sources)
150
- bag_num = len(model.models)
151
- fut_length = 0
152
- prog_bar = 0
153
- current_model = 0
154
-
155
- for sub_model, weight in zip(model.models, model.weights):
156
- original_model_device = next(iter(sub_model.parameters())).device
157
- sub_model.to(device)
158
- fut_length += fut_length
159
- current_model += 1
160
-
161
- out = apply_model(sub_model, mix, **kwargs)
162
- sub_model.to(original_model_device)
163
-
164
- for k, inst_weight in enumerate(weight):
165
- out[:, k, :, :] *= inst_weight
166
- totals[k] += inst_weight
167
-
168
- estimates += out
169
- del out
170
-
171
- for k in range(estimates.shape[1]):
172
- estimates[:, k, :, :] /= totals[k]
173
-
174
- return estimates
175
-
176
- model.to(device)
177
- model.eval()
178
-
179
- assert transition_power >= 1
180
-
181
- batch, channels, length = mix.shape
182
-
183
- if shifts:
184
- kwargs["shifts"] = 0
185
- max_shift = int(0.5 * model.samplerate)
186
- mix = tensor_chunk(mix)
187
- padded_mix = mix.padded(length + 2 * max_shift)
188
- out = 0
189
-
190
- for _ in range(shifts):
191
- offset = random.randint(0, max_shift)
192
- shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
193
- shifted_out = apply_model(model, shifted, **kwargs)
194
- out += shifted_out[..., max_shift - offset :]
195
-
196
- out /= shifts
197
-
198
- return out
199
- elif split:
200
- kwargs["split"] = False
201
- out = torch.zeros(batch, len(model.sources), channels, length, device=mix.device)
202
- sum_weight = torch.zeros(length, device=mix.device)
203
- segment = int(model.samplerate * model.segment)
204
- stride = int((1 - overlap) * segment)
205
- offsets = range(0, length, stride)
206
-
207
- weight = torch.cat([torch.arange(1, segment // 2 + 1, device=device), torch.arange(segment - segment // 2, 0, -1, device=device)])
208
- assert len(weight) == segment
209
-
210
- weight = (weight / weight.max()) ** transition_power
211
- futures = []
212
-
213
- for offset in offsets:
214
- chunk = TensorChunk(mix, offset, segment)
215
- future = pool.submit(apply_model, model, chunk, **kwargs)
216
- futures.append((future, offset))
217
- offset += segment
218
-
219
- if progress: futures = tqdm.tqdm(futures)
220
-
221
- for future, offset in futures:
222
- if set_progress_bar:
223
- fut_length = len(futures) * bag_num * static_shifts
224
- prog_bar += 1
225
- set_progress_bar(0.1, (0.8 / fut_length * prog_bar))
226
-
227
- chunk_out = future.result()
228
- chunk_length = chunk_out.shape[-1]
229
-
230
- out[..., offset : offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
231
- sum_weight[offset : offset + segment] += weight[:chunk_length].to(mix.device)
232
-
233
- assert sum_weight.min() > 0
234
-
235
- out /= sum_weight
236
-
237
- return out
238
- else:
239
- valid_length = model.valid_length(length) if hasattr(model, "valid_length") else length
240
-
241
- mix = tensor_chunk(mix)
242
- padded_mix = mix.padded(valid_length).to(device)
243
-
244
- with torch.no_grad():
245
- out = model(padded_mix)
246
-
247
- return center_trim(out, length)
248
-
249
-
250
- def demucs_segments(demucs_segment, demucs_model):
251
- if demucs_segment == "Default":
252
- segment = None
253
-
254
- if isinstance(demucs_model, BagOfModels):
255
- if segment is not None:
256
- for sub in demucs_model.models:
257
- sub.segment = segment
258
- else:
259
- if segment is not None: sub.segment = segment
260
- else:
261
- try:
262
- segment = int(demucs_segment)
263
-
264
- if isinstance(demucs_model, BagOfModels):
265
- if segment is not None:
266
- for sub in demucs_model.models:
267
- sub.segment = segment
268
- else:
269
- if segment is not None: sub.segment = segment
270
- except:
271
- segment = None
272
-
273
- if isinstance(demucs_model, BagOfModels):
274
- if segment is not None:
275
- for sub in demucs_model.models:
276
- sub.segment = segment
277
- else:
278
- if segment is not None: sub.segment = segment
279
-
280
- return demucs_model
main/library/uvr5_separator/demucs/demucs.py DELETED
@@ -1,340 +0,0 @@
1
- import math
2
- import torch
3
- import julius
4
-
5
- import typing as tp
6
-
7
- from torch import nn
8
-
9
- from torch.nn import functional as F
10
-
11
- from .utils import center_trim
12
- from .states import capture_init
13
-
14
-
15
- def unfold(a, kernel_size, stride):
16
- *shape, length = a.shape
17
- n_frames = math.ceil(length / stride)
18
-
19
- tgt_length = (n_frames - 1) * stride + kernel_size
20
- a = F.pad(a, (0, tgt_length - length))
21
- strides = list(a.stride())
22
-
23
- assert strides[-1] == 1
24
-
25
- strides = strides[:-1] + [stride, 1]
26
-
27
- return a.as_strided([*shape, n_frames, kernel_size], strides)
28
-
29
- def rescale_conv(conv, reference):
30
- scale = (conv.weight.std().detach() / reference) ** 0.5
31
- conv.weight.data /= scale
32
-
33
- if conv.bias is not None: conv.bias.data /= scale
34
-
35
- def rescale_module(module, reference):
36
- for sub in module.modules():
37
- if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): rescale_conv(sub, reference)
38
-
39
- class BLSTM(nn.Module):
40
- def __init__(self, dim, layers=1, max_steps=None, skip=False):
41
- super().__init__()
42
- assert max_steps is None or max_steps % 4 == 0
43
- self.max_steps = max_steps
44
- self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
45
- self.linear = nn.Linear(2 * dim, dim)
46
- self.skip = skip
47
-
48
- def forward(self, x):
49
- B, C, T = x.shape
50
- y = x
51
- framed = False
52
-
53
- if self.max_steps is not None and T > self.max_steps:
54
- width = self.max_steps
55
- stride = width // 2
56
- frames = unfold(x, width, stride)
57
- nframes = frames.shape[2]
58
- framed = True
59
- x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
60
-
61
- x = x.permute(2, 0, 1)
62
-
63
- x = self.lstm(x)[0]
64
- x = self.linear(x)
65
- x = x.permute(1, 2, 0)
66
-
67
- if framed:
68
- out = []
69
- frames = x.reshape(B, -1, C, width)
70
- limit = stride // 2
71
-
72
- for k in range(nframes):
73
- if k == 0: out.append(frames[:, k, :, :-limit])
74
- elif k == nframes - 1: out.append(frames[:, k, :, limit:])
75
- else: out.append(frames[:, k, :, limit:-limit])
76
-
77
- out = torch.cat(out, -1)
78
- out = out[..., :T]
79
- x = out
80
-
81
- if self.skip: x = x + y
82
-
83
- return x
84
-
85
- class LayerScale(nn.Module):
86
- def __init__(self, channels: int, init: float = 0):
87
- super().__init__()
88
- self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
89
- self.scale.data[:] = init
90
-
91
- def forward(self, x):
92
- return self.scale[:, None] * x
93
-
94
- class DConv(nn.Module):
95
- def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, kernel=3, dilate=True):
96
- super().__init__()
97
- assert kernel % 2 == 1
98
- self.channels = channels
99
- self.compress = compress
100
- self.depth = abs(depth)
101
- dilate = depth > 0
102
-
103
- norm_fn: tp.Callable[[int], nn.Module]
104
- norm_fn = lambda d: nn.Identity()
105
-
106
- if norm: norm_fn = lambda d: nn.GroupNorm(1, d)
107
-
108
- hidden = int(channels / compress)
109
-
110
- act: tp.Type[nn.Module]
111
- act = nn.GELU if gelu else nn.ReLU
112
-
113
- self.layers = nn.ModuleList([])
114
-
115
- for d in range(self.depth):
116
- dilation = 2**d if dilate else 1
117
- padding = dilation * (kernel // 2)
118
-
119
- mods = [
120
- nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
121
- norm_fn(hidden),
122
- act(),
123
- nn.Conv1d(hidden, 2 * channels, 1),
124
- norm_fn(2 * channels),
125
- nn.GLU(1),
126
- LayerScale(channels, init),
127
- ]
128
-
129
- if attn: mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
130
- if lstm: mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
131
-
132
- layer = nn.Sequential(*mods)
133
- self.layers.append(layer)
134
-
135
- def forward(self, x):
136
- for layer in self.layers:
137
- x = x + layer(x)
138
-
139
- return x
140
-
141
- class LocalState(nn.Module):
142
- def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
143
- super().__init__()
144
-
145
- assert channels % heads == 0, (channels, heads)
146
-
147
- self.heads = heads
148
- self.nfreqs = nfreqs
149
- self.ndecay = ndecay
150
- self.content = nn.Conv1d(channels, channels, 1)
151
- self.query = nn.Conv1d(channels, channels, 1)
152
- self.key = nn.Conv1d(channels, channels, 1)
153
-
154
- if nfreqs: self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
155
-
156
- if ndecay:
157
- self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
158
- self.query_decay.weight.data *= 0.01
159
-
160
- assert self.query_decay.bias is not None
161
-
162
- self.query_decay.bias.data[:] = -2
163
-
164
- self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
165
-
166
- def forward(self, x):
167
- B, C, T = x.shape
168
- heads = self.heads
169
- indexes = torch.arange(T, device=x.device, dtype=x.dtype)
170
- delta = indexes[:, None] - indexes[None, :]
171
-
172
- queries = self.query(x).view(B, heads, -1, T)
173
- keys = self.key(x).view(B, heads, -1, T)
174
-
175
- dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
176
- dots /= keys.shape[2] ** 0.5
177
-
178
-
179
- if self.nfreqs:
180
- periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
181
- freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
182
- freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs**0.5
183
- dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
184
-
185
- if self.ndecay:
186
- decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
187
- decay_q = self.query_decay(x).view(B, heads, -1, T)
188
- decay_q = torch.sigmoid(decay_q) / 2
189
- decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
190
- dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
191
-
192
-
193
- dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
194
- weights = torch.softmax(dots, dim=2)
195
-
196
- content = self.content(x).view(B, heads, -1, T)
197
- result = torch.einsum("bhts,bhct->bhcs", weights, content)
198
-
199
- if self.nfreqs:
200
- time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
201
- result = torch.cat([result, time_sig], 2)
202
-
203
- result = result.reshape(B, -1, T)
204
- return x + self.proj(result)
205
-
206
- class Demucs(nn.Module):
207
- @capture_init
208
- def __init__(self, sources, audio_channels=2, channels=64, growth=2.0, depth=6, rewrite=True, lstm_layers=0, kernel_size=8, stride=4, context=1, gelu=True, glu=True, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=4, dconv_attn=4, dconv_lstm=4, dconv_init=1e-4, normalize=True, resample=True, rescale=0.1, samplerate=44100, segment=4 * 10):
209
- super().__init__()
210
- self.audio_channels = audio_channels
211
- self.sources = sources
212
- self.kernel_size = kernel_size
213
- self.context = context
214
- self.stride = stride
215
- self.depth = depth
216
- self.resample = resample
217
- self.channels = channels
218
- self.normalize = normalize
219
- self.samplerate = samplerate
220
- self.segment = segment
221
-
222
- self.encoder = nn.ModuleList()
223
- self.decoder = nn.ModuleList()
224
- self.skip_scales = nn.ModuleList()
225
-
226
- if glu:
227
- activation = nn.GLU(dim=1)
228
- ch_scale = 2
229
- else:
230
- activation = nn.ReLU()
231
- ch_scale = 1
232
-
233
-
234
- act2 = nn.GELU if gelu else nn.ReLU
235
-
236
- in_channels = audio_channels
237
- padding = 0
238
-
239
-
240
- for index in range(depth):
241
- norm_fn = lambda d: nn.Identity()
242
-
243
- if index >= norm_starts: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)
244
-
245
- encode = []
246
- encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), norm_fn(channels), act2()]
247
-
248
- attn = index >= dconv_attn
249
- lstm = index >= dconv_lstm
250
-
251
- if dconv_mode & 1: encode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
252
- if rewrite: encode += [nn.Conv1d(channels, ch_scale * channels, 1), norm_fn(ch_scale * channels), activation]
253
-
254
- self.encoder.append(nn.Sequential(*encode))
255
-
256
- decode = []
257
-
258
- out_channels = in_channels if index > 0 else len(self.sources) * audio_channels
259
-
260
- if rewrite: decode += [nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), norm_fn(ch_scale * channels), activation]
261
- if dconv_mode & 2: decode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
262
-
263
- decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride, padding=padding)]
264
-
265
- if index > 0: decode += [norm_fn(out_channels), act2()]
266
-
267
- self.decoder.insert(0, nn.Sequential(*decode))
268
- in_channels = channels
269
- channels = int(growth * channels)
270
-
271
-
272
- channels = in_channels
273
-
274
- self.lstm = BLSTM(channels, lstm_layers) if lstm_layers else None
275
-
276
-
277
- if rescale: rescale_module(self, reference=rescale)
278
-
279
- def valid_length(self, length):
280
- if self.resample: length *= 2
281
-
282
- for _ in range(self.depth):
283
- length = math.ceil((length - self.kernel_size) / self.stride) + 1
284
- length = max(1, length)
285
-
286
- for _ in range(self.depth):
287
- length = (length - 1) * self.stride + self.kernel_size
288
-
289
- if self.resample: length = math.ceil(length / 2)
290
-
291
- return int(length)
292
-
293
- def forward(self, mix):
294
- x = mix
295
- length = x.shape[-1]
296
-
297
- if self.normalize:
298
- mono = mix.mean(dim=1, keepdim=True)
299
- mean = mono.mean(dim=-1, keepdim=True)
300
- std = mono.std(dim=-1, keepdim=True)
301
- x = (x - mean) / (1e-5 + std)
302
- else:
303
- mean = 0
304
- std = 1
305
-
306
- delta = self.valid_length(length) - length
307
- x = F.pad(x, (delta // 2, delta - delta // 2))
308
-
309
- if self.resample: x = julius.resample_frac(x, 1, 2)
310
-
311
- saved = []
312
-
313
- for encode in self.encoder:
314
- x = encode(x)
315
- saved.append(x)
316
-
317
- if self.lstm: x = self.lstm(x)
318
-
319
- for decode in self.decoder:
320
- skip = saved.pop(-1)
321
- skip = center_trim(skip, x)
322
- x = decode(x + skip)
323
-
324
- if self.resample: x = julius.resample_frac(x, 2, 1)
325
-
326
- x = x * std + mean
327
- x = center_trim(x, length)
328
- x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
329
-
330
- return x
331
-
332
- def load_state_dict(self, state, strict=True):
333
- for idx in range(self.depth):
334
- for a in ["encoder", "decoder"]:
335
- for b in ["bias", "weight"]:
336
- new = f"{a}.{idx}.3.{b}"
337
- old = f"{a}.{idx}.2.{b}"
338
-
339
- if old in state and new not in state: state[new] = state.pop(old)
340
- super().load_state_dict(state, strict=strict)
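A rough sketch of combining the Demucs model above with the apply_model helper from apply.py; the model here is untrained, so the separated output is meaningless, and the import paths are assumptions.

    import torch
    from main.library.uvr5_separator.demucs.demucs import Demucs      # assumed import path
    from main.library.uvr5_separator.demucs.apply import apply_model  # assumed import path

    model = Demucs(sources=["drums", "bass", "other", "vocals"], samplerate=44100, segment=10)
    mix = torch.zeros(1, 2, 44100 * 4)  # (batch, stereo channels, 4 seconds of samples)

    with torch.no_grad():
        out = apply_model(model, mix, shifts=0, split=True, overlap=0.25, device="cpu")
    print(out.shape)  # torch.Size([1, 4, 2, 176400]) -- one stereo track per source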
main/library/uvr5_separator/demucs/hdemucs.py DELETED
@@ -1,850 +0,0 @@
1
- import math
2
- import torch
3
-
4
- import typing as tp
5
-
6
- from torch import nn
7
- from copy import deepcopy
8
- from typing import Optional
9
-
10
- from torch.nn import functional as F
11
-
12
- from .states import capture_init
13
- from .demucs import DConv, rescale_module
14
-
15
-
16
- def spectro(x, n_fft=512, hop_length=None, pad=0):
17
- *other, length = x.shape
18
- x = x.reshape(-1, length)
19
- device_type = x.device.type
20
- is_other_gpu = not device_type in ["cuda", "cpu"]
21
-
22
- if is_other_gpu: x = x.cpu()
23
-
24
- z = torch.stft(x, n_fft * (1 + pad), hop_length or n_fft // 4, window=torch.hann_window(n_fft).to(x), win_length=n_fft, normalized=True, center=True, return_complex=True, pad_mode="reflect")
25
- _, freqs, frame = z.shape
26
-
27
- return z.view(*other, freqs, frame)
28
-
29
-
30
- def ispectro(z, hop_length=None, length=None, pad=0):
31
- *other, freqs, frames = z.shape
32
- n_fft = 2 * freqs - 2
33
- z = z.view(-1, freqs, frames)
34
- win_length = n_fft // (1 + pad)
35
- device_type = z.device.type
36
- is_other_gpu = not device_type in ["cuda", "cpu"]
37
-
38
- if is_other_gpu: z = z.cpu()
39
-
40
- x = torch.istft(z, n_fft, hop_length, window=torch.hann_window(win_length).to(z.real), win_length=win_length, normalized=True, length=length, center=True)
41
- _, length = x.shape
42
-
43
- return x.view(*other, length)
44
-
45
-
46
- def atan2(y, x):
47
- pi = 2 * torch.asin(torch.tensor(1.0))
48
- x += ((x == 0) & (y == 0)) * 1.0
49
-
50
- out = torch.atan(y / x)
51
- out += ((y >= 0) & (x < 0)) * pi
52
- out -= ((y < 0) & (x < 0)) * pi
53
- out *= 1 - ((y > 0) & (x == 0)) * 1.0
54
- out += ((y > 0) & (x == 0)) * (pi / 2)
55
- out *= 1 - ((y < 0) & (x == 0)) * 1.0
56
- out += ((y < 0) & (x == 0)) * (-pi / 2)
57
-
58
- return out
59
-
60
-
61
- def _norm(x: torch.Tensor) -> torch.Tensor:
62
- return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
63
-
64
-
65
- def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
66
- target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
67
-
68
- if out is None or out.shape != target_shape: out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
69
-
70
- if out is a:
71
- real_a = a[..., 0]
72
- out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
73
- out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
74
- else:
75
- out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
76
- out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
77
-
78
- return out
79
-
80
-
81
- def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
82
- target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
83
-
84
- if out is None or out.shape != target_shape: out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
85
-
86
- if out is a:
87
- real_a = a[..., 0]
88
- out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
89
- out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
90
- else:
91
- out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
92
- out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]
93
-
94
- return out
95
-
96
-
97
- def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
98
- ez = _norm(z)
99
-
100
- if out is None or out.shape != z.shape: out = torch.zeros_like(z)
101
-
102
- out[..., 0] = z[..., 0] / ez
103
- out[..., 1] = -z[..., 1] / ez
104
-
105
- return out
106
-
107
-
108
- def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
109
- if out is None or out.shape != z.shape: out = torch.zeros_like(z)
110
-
111
- out[..., 0] = z[..., 0]
112
- out[..., 1] = -z[..., 1]
113
-
114
- return out
115
-
116
-
117
- def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
118
- nb_channels = M.shape[-2]
119
-
120
- if out is None or out.shape != M.shape: out = torch.empty_like(M)
121
-
122
- if nb_channels == 1: out = _inv(M, out)
123
- elif nb_channels == 2:
124
- det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
125
- det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
126
- invDet = _inv(det)
127
-
128
- out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
129
- out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
130
- out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
131
- out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
132
- else: raise Exception("Torch == 2 Channels")
133
-
134
- return out
135
-
136
-
137
- def expectation_maximization(y: torch.Tensor, x: torch.Tensor, iterations: int = 2, eps: float = 1e-10, batch_size: int = 200):
138
- (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
139
- nb_sources = y.shape[-1]
140
-
141
- regularization = torch.cat((torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None], torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device)), dim=2)
142
- regularization = torch.sqrt(torch.as_tensor(eps)) * (regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1)))
143
-
144
- R = [torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device) for j in range(nb_sources)]
145
- weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)
146
-
147
- v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)
148
-
149
- for _ in range(iterations):
150
- v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)
151
-
152
- for j in range(nb_sources):
153
- R[j] = torch.tensor(0.0, device=x.device)
154
-
155
- weight = torch.tensor(eps, device=x.device)
156
- pos: int = 0
157
-
158
- batch_size = batch_size if batch_size else nb_frames
159
-
160
- while pos < nb_frames:
161
- t = torch.arange(pos, min(nb_frames, pos + batch_size))
162
- pos = int(t[-1]) + 1
163
-
164
- R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
165
- weight = weight + torch.sum(v[t, ..., j], dim=0)
166
-
167
- R[j] = R[j] / weight[..., None, None, None]
168
- weight = torch.zeros_like(weight)
169
-
170
- if y.requires_grad: y = y.clone()
171
-
172
- pos = 0
173
-
174
- while pos < nb_frames:
175
- t = torch.arange(pos, min(nb_frames, pos + batch_size))
176
- pos = int(t[-1]) + 1
177
-
178
- y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)
179
-
180
- Cxx = regularization
181
-
182
- for j in range(nb_sources):
183
- Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())
184
-
185
- inv_Cxx = _invert(Cxx)
186
-
187
- for j in range(nb_sources):
188
- gain = torch.zeros_like(inv_Cxx)
189
-
190
- indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels), torch.arange(nb_channels))
191
-
192
- for index in indices:
193
- gain[:, :, index[0], index[1], :] = _mul_add(R[j][None, :, index[0], index[2], :].clone(), inv_Cxx[:, :, index[2], index[1], :], gain[:, :, index[0], index[1], :])
194
-
195
- gain = gain * v[t, ..., None, None, None, j]
196
-
197
- for i in range(nb_channels):
198
- y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])
199
-
200
- return y, v, R
201
-
202
-
203
- def wiener(targets_spectrograms: torch.Tensor, mix_stft: torch.Tensor, iterations: int = 1, softmask: bool = False, residual: bool = False, scale_factor: float = 10.0, eps: float = 1e-10):
204
- if softmask: y = mix_stft[..., None] * (targets_spectrograms / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype)))[..., None, :]
205
- else:
206
- angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
207
- nb_sources = targets_spectrograms.shape[-1]
208
- y = torch.zeros(mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device)
209
- y[..., 0, :] = targets_spectrograms * torch.cos(angle)
210
- y[..., 1, :] = targets_spectrograms * torch.sin(angle)
211
-
212
- if residual: y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
213
- if iterations == 0: return y
214
-
215
- max_abs = torch.max(torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device), torch.sqrt(_norm(mix_stft)).max() / scale_factor)
216
-
217
- mix_stft = mix_stft / max_abs
218
- y = y / max_abs
219
-
220
- y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]
221
-
222
- y = y * max_abs
223
- return y
224
-
225
-
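
When softmask is off, wiener seeds every source estimate with its target magnitude and the phase of the mixture before EM refinement. A small self-contained sketch of that construction (the shapes are illustrative, matching the (frames, bins, channels, real/imag) layout used above):

    import torch

    mix_stft = torch.randn(100, 2049, 2, 2)   # (frames, bins, channels, real/imag)
    target_mag = torch.rand(100, 2049, 2)     # one target's magnitude spectrogram
    angle = torch.atan2(mix_stft[..., 1], mix_stft[..., 0])
    y_real = target_mag * torch.cos(angle)
    y_imag = target_mag * torch.sin(angle)
    # the estimate keeps the target magnitude while borrowing the mixture phase
    assert torch.allclose(torch.sqrt(y_real ** 2 + y_imag ** 2), target_mag, atol=1e-5)
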
226
- def _covariance(y_j):
227
- (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]
228
-
229
- Cj = torch.zeros((nb_frames, nb_bins, nb_channels, nb_channels, 2), dtype=y_j.dtype, device=y_j.device)
230
- indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))
231
-
232
- for index in indices:
233
- Cj[:, :, index[0], index[1], :] = _mul_add(y_j[:, :, index[0], :], _conj(y_j[:, :, index[1], :]), Cj[:, :, index[0], index[1], :])
234
-
235
- return Cj
236
-
237
-
238
- def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = "constant", value: float = 0.0):
239
- x0 = x
240
- length = x.shape[-1]
241
- padding_left, padding_right = paddings
242
-
243
- if mode == "reflect":
244
- max_pad = max(padding_left, padding_right)
245
-
246
- if length <= max_pad:
247
- extra_pad = max_pad - length + 1
248
- extra_pad_right = min(padding_right, extra_pad)
249
- extra_pad_left = extra_pad - extra_pad_right
250
- paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
251
- x = F.pad(x, (extra_pad_left, extra_pad_right))
252
-
253
- out = F.pad(x, paddings, mode, value)
254
-
255
- assert out.shape[-1] == length + padding_left + padding_right
256
- assert (out[..., padding_left : padding_left + length] == x0).all()
257
-
258
- return out
259
-
260
-
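
pad1d exists because torch's reflection padding requires the pad to be strictly smaller than the signal; when it is not, the signal is first zero-padded by just enough for the reflection to become legal. A self-contained sketch of the same fallback arithmetic:

    import torch
    import torch.nn.functional as F

    x = torch.randn(1, 1, 5)
    padding_left = padding_right = 8                            # more padding than samples
    extra = max(padding_left, padding_right) - x.shape[-1] + 1
    extra_right = min(padding_right, extra)
    x = F.pad(x, (extra - extra_right, extra_right))            # grow the signal first
    out = F.pad(x, (padding_left - (extra - extra_right), padding_right - extra_right), mode="reflect")
    assert out.shape[-1] == 5 + padding_left + padding_right
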
261
- class ScaledEmbedding(nn.Module):
262
- def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth=False):
263
- super().__init__()
264
-
265
- self.embedding = nn.Embedding(num_embeddings, embedding_dim)
266
-
267
- if smooth:
268
- weight = torch.cumsum(self.embedding.weight.data, dim=0)
269
- weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]
270
-
271
- self.embedding.weight.data[:] = weight
272
-
273
- self.embedding.weight.data /= scale
274
- self.scale = scale
275
-
276
- @property
277
- def weight(self):
278
- return self.embedding.weight * self.scale
279
-
280
- def forward(self, x):
281
- return self.embedding(x) * self.scale
282
-
283
-
284
- class HEncLayer(nn.Module):
285
- def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, rewrite=True):
286
- super().__init__()
287
- norm_fn = lambda d: nn.Identity()
288
-
289
- if norm: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)
290
-
291
- pad = kernel_size // 4 if pad else 0
292
-
293
- klass = nn.Conv1d
294
- self.freq = freq
295
- self.kernel_size = kernel_size
296
- self.stride = stride
297
- self.empty = empty
298
- self.norm = norm
299
- self.pad = pad
300
-
301
- if freq:
302
- kernel_size = [kernel_size, 1]
303
- stride = [stride, 1]
304
- pad = [pad, 0]
305
- klass = nn.Conv2d
306
-
307
- self.conv = klass(chin, chout, kernel_size, stride, pad)
308
-
309
- if self.empty: return
310
-
311
- self.norm1 = norm_fn(chout)
312
- self.rewrite = None
313
-
314
- if rewrite:
315
- self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
316
- self.norm2 = norm_fn(2 * chout)
317
-
318
- self.dconv = None
319
-
320
- if dconv: self.dconv = DConv(chout, **dconv_kw)
321
-
322
- def forward(self, x, inject=None):
323
- if not self.freq and x.dim() == 4:
324
- B, C, Fr, T = x.shape
325
- x = x.view(B, -1, T)
326
-
327
- if not self.freq:
328
- le = x.shape[-1]
329
-
330
- if not le % self.stride == 0: x = F.pad(x, (0, self.stride - (le % self.stride)))
331
-
332
- y = self.conv(x)
333
-
334
- if self.empty: return y
335
-
336
- if inject is not None:
337
- assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)
338
-
339
- if inject.dim() == 3 and y.dim() == 4: inject = inject[:, :, None]
340
-
341
- y = y + inject
342
-
343
- y = F.gelu(self.norm1(y))
344
-
345
-
346
- if self.dconv:
347
- if self.freq:
348
- B, C, Fr, T = y.shape
349
- y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
350
-
351
- y = self.dconv(y)
352
-
353
- if self.freq: y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
354
-
355
- if self.rewrite:
356
- z = self.norm2(self.rewrite(y))
357
- z = F.glu(z, dim=1)
358
- else: z = y
359
-
360
- return z
361
-
362
-
363
- class MultiWrap(nn.Module):
364
- def __init__(self, layer, split_ratios):
365
- super().__init__()
366
-
367
- self.split_ratios = split_ratios
368
- self.layers = nn.ModuleList()
369
- self.conv = isinstance(layer, HEncLayer)
370
-
371
- assert not layer.norm
372
- assert layer.freq
373
- assert layer.pad
374
-
375
- if not self.conv: assert not layer.context_freq
376
-
377
- for _ in range(len(split_ratios) + 1):
378
- lay = deepcopy(layer)
379
-
380
- if self.conv: lay.conv.padding = (0, 0)
381
- else: lay.pad = False
382
-
383
- for m in lay.modules():
384
- if hasattr(m, "reset_parameters"): m.reset_parameters()
385
-
386
- self.layers.append(lay)
387
-
388
- def forward(self, x, skip=None, length=None):
389
- B, C, Fr, T = x.shape
390
-
391
- ratios = list(self.split_ratios) + [1]
392
- start = 0
393
- outs = []
394
-
395
- for ratio, layer in zip(ratios, self.layers):
396
- if self.conv:
397
- pad = layer.kernel_size // 4
398
-
399
- if ratio == 1:
400
- limit = Fr
401
- frames = -1
402
- else:
403
- limit = int(round(Fr * ratio))
404
- le = limit - start
405
-
406
- if start == 0: le += pad
407
-
408
- frames = round((le - layer.kernel_size) / layer.stride + 1)
409
- limit = start + (frames - 1) * layer.stride + layer.kernel_size
410
-
411
- if start == 0: limit -= pad
412
-
413
- assert limit - start > 0, (limit, start)
414
- assert limit <= Fr, (limit, Fr)
415
-
416
- y = x[:, :, start:limit, :]
417
-
418
- if start == 0: y = F.pad(y, (0, 0, pad, 0))
419
- if ratio == 1: y = F.pad(y, (0, 0, 0, pad))
420
-
421
- outs.append(layer(y))
422
- start = limit - layer.kernel_size + layer.stride
423
- else:
424
- limit = Fr if ratio == 1 else int(round(Fr * ratio))
425
-
426
- last = layer.last
427
- layer.last = True
428
-
429
- y = x[:, :, start:limit]
430
- s = skip[:, :, start:limit]
431
- out, _ = layer(y, s, None)
432
-
433
- if outs:
434
- outs[-1][:, :, -layer.stride :] += out[:, :, : layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)
435
- out = out[:, :, layer.stride :]
436
-
437
- if ratio == 1: out = out[:, :, : -layer.stride // 2, :]
438
- if start == 0: out = out[:, :, layer.stride // 2 :, :]
439
-
440
- outs.append(out)
441
- layer.last = last
442
- start = limit
443
-
444
- out = torch.cat(outs, dim=2)
445
-
446
- if not self.conv and not last: out = F.gelu(out)
447
-
448
- if self.conv: return out
449
- else: return out, None
450
-
451
-
452
- class HDecLayer(nn.Module):
453
- def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, context_freq=True, rewrite=True):
454
- super().__init__()
455
- norm_fn = lambda d: nn.Identity()
456
-
457
- if norm: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)
458
-
459
- pad = kernel_size // 4 if pad else 0
460
-
461
- self.pad = pad
462
- self.last = last
463
- self.freq = freq
464
- self.chin = chin
465
- self.empty = empty
466
- self.stride = stride
467
- self.kernel_size = kernel_size
468
- self.norm = norm
469
- self.context_freq = context_freq
470
- klass = nn.Conv1d
471
- klass_tr = nn.ConvTranspose1d
472
-
473
- if freq:
474
- kernel_size = [kernel_size, 1]
475
- stride = [stride, 1]
476
- klass = nn.Conv2d
477
- klass_tr = nn.ConvTranspose2d
478
-
479
- self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
480
- self.norm2 = norm_fn(chout)
481
-
482
- if self.empty: return
483
-
484
- self.rewrite = None
485
-
486
- if rewrite:
487
- if context_freq: self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
488
- else: self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, [0, context])
489
-
490
- self.norm1 = norm_fn(2 * chin)
491
-
492
- self.dconv = None
493
-
494
- if dconv: self.dconv = DConv(chin, **dconv_kw)
495
-
496
- def forward(self, x, skip, length):
497
- if self.freq and x.dim() == 3:
498
- B, C, T = x.shape
499
- x = x.view(B, self.chin, -1, T)
500
-
501
- if not self.empty:
502
- x = x + skip
503
-
504
- y = F.glu(self.norm1(self.rewrite(x)), dim=1) if self.rewrite else x
505
-
506
- if self.dconv:
507
- if self.freq:
508
- B, C, Fr, T = y.shape
509
- y = y.permute(0, 2, 1, 3).reshape(-1, C, T)
510
-
511
- y = self.dconv(y)
512
-
513
- if self.freq: y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
514
- else:
515
- y = x
516
- assert skip is None
517
-
518
- z = self.norm2(self.conv_tr(y))
519
-
520
- if self.freq:
521
- if self.pad: z = z[..., self.pad : -self.pad, :]
522
- else:
523
- z = z[..., self.pad : self.pad + length]
524
- assert z.shape[-1] == length, (z.shape[-1], length)
525
-
526
- if not self.last: z = F.gelu(z)
527
-
528
- return z, y
529
-
530
-
531
- class HDemucs(nn.Module):
532
- @capture_init
533
- def __init__(self, sources, audio_channels=2, channels=48, channels_time=None, growth=2, nfft=4096, wiener_iters=0, end_iters=0, wiener_residual=False, cac=True, depth=6, rewrite=True, hybrid=True, hybrid_old=False, multi_freqs=None, multi_freqs_depth=2, freq_emb=0.2, emb_scale=10, emb_smooth=True, kernel_size=8, time_stride=2, stride=4, context=1, context_enc=0, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=4, dconv_attn=4, dconv_lstm=4, dconv_init=1e-4, rescale=0.1, samplerate=44100, segment=4 * 10):
534
- super().__init__()
535
-
536
- self.cac = cac
537
- self.wiener_residual = wiener_residual
538
- self.audio_channels = audio_channels
539
- self.sources = sources
540
- self.kernel_size = kernel_size
541
- self.context = context
542
- self.stride = stride
543
- self.depth = depth
544
- self.channels = channels
545
- self.samplerate = samplerate
546
- self.segment = segment
547
- self.nfft = nfft
548
- self.hop_length = nfft // 4
549
- self.wiener_iters = wiener_iters
550
- self.end_iters = end_iters
551
- self.freq_emb = None
552
- self.hybrid = hybrid
553
- self.hybrid_old = hybrid_old
554
-
555
- if hybrid_old: assert hybrid
556
- if hybrid: assert wiener_iters == end_iters
557
-
558
- self.encoder = nn.ModuleList()
559
- self.decoder = nn.ModuleList()
560
-
561
- if hybrid:
562
- self.tencoder = nn.ModuleList()
563
- self.tdecoder = nn.ModuleList()
564
-
565
- chin = audio_channels
566
- chin_z = chin
567
-
568
- if self.cac: chin_z *= 2
569
-
570
- chout = channels_time or channels
571
- chout_z = channels
572
- freqs = nfft // 2
573
-
574
- for index in range(depth):
575
- lstm = index >= dconv_lstm
576
- attn = index >= dconv_attn
577
- norm = index >= norm_starts
578
- freq = freqs > 1
579
- stri = stride
580
- ker = kernel_size
581
-
582
- if not freq:
583
- assert freqs == 1
584
-
585
- ker = time_stride * 2
586
- stri = time_stride
587
-
588
- pad = True
589
- last_freq = False
590
-
591
- if freq and freqs <= kernel_size:
592
- ker = freqs
593
- pad = False
594
- last_freq = True
595
-
596
- kw = {
597
- "kernel_size": ker,
598
- "stride": stri,
599
- "freq": freq,
600
- "pad": pad,
601
- "norm": norm,
602
- "rewrite": rewrite,
603
- "norm_groups": norm_groups,
604
- "dconv_kw": {"lstm": lstm, "attn": attn, "depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
605
- }
606
-
607
- kwt = dict(kw)
608
- kwt["freq"] = 0
609
- kwt["kernel_size"] = kernel_size
610
- kwt["stride"] = stride
611
- kwt["pad"] = True
612
- kw_dec = dict(kw)
613
-
614
- multi = False
615
-
616
- if multi_freqs and index < multi_freqs_depth:
617
- multi = True
618
- kw_dec["context_freq"] = False
619
-
620
- if last_freq:
621
- chout_z = max(chout, chout_z)
622
- chout = chout_z
623
-
624
- enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
625
-
626
- if hybrid and freq:
627
- tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
628
- self.tencoder.append(tenc)
629
-
630
- if multi: enc = MultiWrap(enc, multi_freqs)
631
-
632
- self.encoder.append(enc)
633
-
634
- if index == 0:
635
- chin = self.audio_channels * len(self.sources)
636
- chin_z = chin
637
-
638
- if self.cac: chin_z *= 2
639
-
640
- dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
641
-
642
- if multi: dec = MultiWrap(dec, multi_freqs)
643
-
644
- if hybrid and freq:
645
- tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
646
- self.tdecoder.insert(0, tdec)
647
-
648
- self.decoder.insert(0, dec)
649
-
650
- chin = chout
651
- chin_z = chout_z
652
-
653
- chout = int(growth * chout)
654
- chout_z = int(growth * chout_z)
655
-
656
- if freq:
657
- if freqs <= kernel_size: freqs = 1
658
- else: freqs //= stride
659
-
660
- if index == 0 and freq_emb:
661
- self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
662
- self.freq_emb_scale = freq_emb
663
-
664
- if rescale: rescale_module(self, reference=rescale)
665
-
666
- def _spec(self, x):
667
- hl = self.hop_length
668
- nfft = self.nfft
669
-
670
- if self.hybrid:
671
- assert hl == nfft // 4
672
-
673
- le = int(math.ceil(x.shape[-1] / hl))
674
- pad = hl // 2 * 3
675
-
676
- x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") if not self.hybrid_old else pad1d(x, (pad, pad + le * hl - x.shape[-1]))
677
-
678
- z = spectro(x, nfft, hl)[..., :-1, :]
679
-
680
- if self.hybrid:
681
- assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
682
- z = z[..., 2 : 2 + le]
683
-
684
- return z
685
-
686
- def _ispec(self, z, length=None, scale=0):
687
- hl = self.hop_length // (4**scale)
688
- z = F.pad(z, (0, 0, 0, 1))
689
-
690
- if self.hybrid:
691
- z = F.pad(z, (2, 2))
692
- pad = hl // 2 * 3
693
- le = hl * int(math.ceil(length / hl)) + 2 * pad if not self.hybrid_old else hl * int(math.ceil(length / hl))
694
-
695
- x = ispectro(z, hl, length=le)
696
- x = x[..., pad : pad + length] if not self.hybrid_old else x[..., :length]
697
- else: x = ispectro(z, hl, length)
698
-
699
- return x
700
-
701
- def _magnitude(self, z):
702
- if self.cac:
703
- B, C, Fr, T = z.shape
704
- m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
705
- m = m.reshape(B, C * 2, Fr, T)
706
- else: m = z.abs()
707
-
708
- return m
709
-
710
- def _mask(self, z, m):
711
- niters = self.wiener_iters
712
-
713
- if self.cac:
714
- B, S, C, Fr, T = m.shape
715
- out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
716
- out = torch.view_as_complex(out.contiguous())
717
- return out
718
-
719
- if self.training: niters = self.end_iters
720
-
721
- if niters < 0:
722
- z = z[:, None]
723
- return z / (1e-8 + z.abs()) * m
724
- else: return self._wiener(m, z, niters)
725
-
726
- def _wiener(self, mag_out, mix_stft, niters):
727
- init = mix_stft.dtype
728
- wiener_win_len = 300
729
- residual = self.wiener_residual
730
-
731
- B, S, C, Fq, T = mag_out.shape
732
- mag_out = mag_out.permute(0, 4, 3, 2, 1)
733
- mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
734
-
735
- outs = []
736
-
737
- for sample in range(B):
738
- pos = 0
739
- out = []
740
-
741
- for pos in range(0, T, wiener_win_len):
742
- frame = slice(pos, pos + wiener_win_len)
743
- z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
744
- out.append(z_out.transpose(-1, -2))
745
-
746
- outs.append(torch.cat(out, dim=0))
747
-
748
- out = torch.view_as_complex(torch.stack(outs, 0))
749
- out = out.permute(0, 4, 3, 2, 1).contiguous()
750
-
751
- if residual: out = out[:, :-1]
752
-
753
- assert list(out.shape) == [B, S, C, Fq, T]
754
- return out.to(init)
755
-
756
- def forward(self, mix):
757
- x = mix
758
- length = x.shape[-1]
759
-
760
- z = self._spec(mix)
761
- mag = self._magnitude(z).to(mix.device)
762
- x = mag
763
-
764
- B, C, Fq, T = x.shape
765
-
766
- mean = x.mean(dim=(1, 2, 3), keepdim=True)
767
- std = x.std(dim=(1, 2, 3), keepdim=True)
768
- x = (x - mean) / (1e-5 + std)
769
-
770
- if self.hybrid:
771
- xt = mix
772
- meant = xt.mean(dim=(1, 2), keepdim=True)
773
- stdt = xt.std(dim=(1, 2), keepdim=True)
774
- xt = (xt - meant) / (1e-5 + stdt)
775
-
776
- saved = []
777
- saved_t = []
778
- lengths = []
779
- lengths_t = []
780
-
781
- for idx, encode in enumerate(self.encoder):
782
- lengths.append(x.shape[-1])
783
- inject = None
784
-
785
- if self.hybrid and idx < len(self.tencoder):
786
- lengths_t.append(xt.shape[-1])
787
- tenc = self.tencoder[idx]
788
- xt = tenc(xt)
789
-
790
- if not tenc.empty: saved_t.append(xt)
791
- else: inject = xt
792
-
793
- x = encode(x, inject)
794
-
795
- if idx == 0 and self.freq_emb is not None:
796
- frs = torch.arange(x.shape[-2], device=x.device)
797
- emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
798
- x = x + self.freq_emb_scale * emb
799
-
800
- saved.append(x)
801
-
802
- x = torch.zeros_like(x)
803
-
804
- if self.hybrid: xt = torch.zeros_like(x)
805
-
806
- for idx, decode in enumerate(self.decoder):
807
- skip = saved.pop(-1)
808
- x, pre = decode(x, skip, lengths.pop(-1))
809
-
810
- if self.hybrid: offset = self.depth - len(self.tdecoder)
811
-
812
- if self.hybrid and idx >= offset:
813
- tdec = self.tdecoder[idx - offset]
814
- length_t = lengths_t.pop(-1)
815
-
816
- if tdec.empty:
817
- assert pre.shape[2] == 1, pre.shape
818
-
819
- pre = pre[:, :, 0]
820
- xt, _ = tdec(pre, None, length_t)
821
- else:
822
- skip = saved_t.pop(-1)
823
- xt, _ = tdec(xt, skip, length_t)
824
-
825
- assert len(saved) == 0
826
- assert len(lengths_t) == 0
827
- assert len(saved_t) == 0
828
-
829
- S = len(self.sources)
830
-
831
- x = x.view(B, S, -1, Fq, T)
832
- x = x * std[:, None] + mean[:, None]
833
-
834
- device_type = x.device.type
835
- device_load = f"{device_type}:{x.device.index}" if not device_type == "mps" else device_type
836
- x_is_other_gpu = not device_type in ["cuda", "cpu"]
837
-
838
- if x_is_other_gpu: x = x.cpu()
839
-
840
- zout = self._mask(z, x)
841
- x = self._ispec(zout, length)
842
-
843
- if x_is_other_gpu: x = x.to(device_load)
844
-
845
- if self.hybrid:
846
- xt = xt.view(B, S, -1, length)
847
- xt = xt * stdt[:, None] + meant[:, None]
848
- x = xt + x
849
-
850
- return x
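
With cac=True, _magnitude packs the real and imaginary parts of the complex spectrogram as extra channels and _mask unpacks them again per source. A minimal self-contained round trip of that packing (single source, illustrative sizes):

    import torch

    B, C, Fr, T, S = 1, 2, 4, 5, 1
    z = torch.randn(B, C, Fr, T, dtype=torch.complex64)
    m = torch.view_as_real(z).permute(0, 1, 4, 2, 3).reshape(B, C * 2, Fr, T)  # _magnitude, cac path
    out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)                 # _mask, cac path
    out = torch.view_as_complex(out.contiguous())
    assert torch.equal(out[:, 0], z)
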
main/library/uvr5_separator/demucs/htdemucs.py DELETED
@@ -1,690 +0,0 @@
1
- import os
2
- import sys
3
- import math
4
- import torch
5
- import random
6
-
7
- import numpy as np
8
- import typing as tp
9
-
10
- from torch import nn
11
- from einops import rearrange
12
- from fractions import Fraction
13
- from torch.nn import functional as F
14
-
15
-
16
- now_dir = os.getcwd()
17
- sys.path.append(now_dir)
18
-
19
- from .states import capture_init
20
- from .demucs import rescale_module
21
- from main.configs.config import Config
22
- from .hdemucs import pad1d, spectro, ispectro, wiener, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer
23
-
24
- translations = Config().translations
25
-
26
-
27
- def create_sin_embedding(length: int, dim: int, shift: int = 0, device="cpu", max_period=10000):
28
- assert dim % 2 == 0
29
-
30
- pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
31
- half_dim = dim // 2
32
-
33
- adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
34
- phase = pos / (max_period ** (adim / (half_dim - 1)))
35
-
36
- return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
37
-
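
create_sin_embedding builds the phase as pos / max_period ** (i / (dim/2 - 1)) and concatenates cosines and sines, yielding a (length, 1, dim) tensor. A minimal self-contained check of that shape:

    import torch

    length, dim, max_period = 8, 16, 10000
    pos = torch.arange(length).view(-1, 1, 1)
    adim = torch.arange(dim // 2).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (dim // 2 - 1)))
    emb = torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)
    assert emb.shape == (length, 1, dim)
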
38
- def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
39
- if d_model % 4 != 0: raise ValueError(translations["dims"].format(dims=d_model))
40
-
41
- pe = torch.zeros(d_model, height, width)
42
-
43
- d_model = int(d_model / 2)
44
-
45
- div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model))
46
-
47
- pos_w = torch.arange(0.0, width).unsqueeze(1)
48
- pos_h = torch.arange(0.0, height).unsqueeze(1)
49
-
50
- pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
51
- pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
52
- pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
53
- pe[d_model + 1 :: 2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
54
-
55
- return pe[None, :].to(device)
56
-
57
- def create_sin_embedding_cape( length: int, dim: int, batch_size: int, mean_normalize: bool, augment: bool, max_global_shift: float = 0.0, max_local_shift: float = 0.0, max_scale: float = 1.0, device: str = "cpu", max_period: float = 10000.0):
58
- assert dim % 2 == 0
59
-
60
- pos = 1.0 * torch.arange(length).view(-1, 1, 1)
61
- pos = pos.repeat(1, batch_size, 1)
62
-
63
- if mean_normalize: pos -= torch.nanmean(pos, dim=0, keepdim=True)
64
-
65
- if augment:
66
- delta = np.random.uniform(-max_global_shift, +max_global_shift, size=[1, batch_size, 1])
67
- delta_local = np.random.uniform(-max_local_shift, +max_local_shift, size=[length, batch_size, 1])
68
-
69
- log_lambdas = np.random.uniform(-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1])
70
- pos = (pos + delta + delta_local) * np.exp(log_lambdas)
71
-
72
- pos = pos.to(device)
73
-
74
- half_dim = dim // 2
75
- adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
76
- phase = pos / (max_period ** (adim / (half_dim - 1)))
77
-
78
- return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()
79
-
80
-
81
- class MyGroupNorm(nn.GroupNorm):
82
- def __init__(self, *args, **kwargs):
83
- super().__init__(*args, **kwargs)
84
-
85
- def forward(self, x):
86
- x = x.transpose(1, 2)
87
- return super().forward(x).transpose(1, 2)
88
-
89
- class LayerScale(nn.Module):
90
- def __init__(self, channels: int, init: float = 0, channel_last=False):
91
- super().__init__()
92
- self.channel_last = channel_last
93
- self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
94
- self.scale.data[:] = init
95
-
96
- def forward(self, x):
97
- if self.channel_last: return self.scale * x
98
- else: return self.scale[:, None] * x
99
-
100
- class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
101
- def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu, group_norm=0, norm_first=False, norm_out=False, layer_norm_eps=1e-5, layer_scale=False, init_values=1e-4, device=None, dtype=None, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, auto_sparsity=False, sparsity=0.95, batch_first=False):
102
- factory_kwargs = {"device": device, "dtype": dtype}
103
- super().__init__(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, batch_first=batch_first, norm_first=norm_first, device=device, dtype=dtype)
104
-
105
- self.auto_sparsity = auto_sparsity
106
-
107
- if group_norm:
108
- self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
109
- self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
110
-
111
- self.norm_out = None
112
-
113
- if self.norm_first & norm_out: self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
114
-
115
- self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
116
- self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
117
-
118
- def forward(self, src, src_mask=None, src_key_padding_mask=None):
119
- x = src
120
- T, B, C = x.shape
121
-
122
- if self.norm_first:
123
- x = x + self.gamma_1(self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
124
- x = x + self.gamma_2(self._ff_block(self.norm2(x)))
125
-
126
- if self.norm_out: x = self.norm_out(x)
127
- else:
128
- x = self.norm1(x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)))
129
- x = self.norm2(x + self.gamma_2(self._ff_block(x)))
130
-
131
- return x
132
-
133
- class CrossTransformerEncoder(nn.Module):
134
- def __init__(self, dim: int, emb: str = "sin", hidden_scale: float = 4.0, num_heads: int = 8, num_layers: int = 6, cross_first: bool = False, dropout: float = 0.0, max_positions: int = 1000, norm_in: bool = True, norm_in_group: bool = False, group_norm: int = False, norm_first: bool = False, norm_out: bool = False, max_period: float = 10000.0, weight_decay: float = 0.0, lr: tp.Optional[float] = None, layer_scale: bool = False, gelu: bool = True, sin_random_shift: int = 0, weight_pos_embed: float = 1.0, cape_mean_normalize: bool = True, cape_augment: bool = True, cape_glob_loc_scale: list = [5000.0, 1.0, 1.4], sparse_self_attn: bool = False, sparse_cross_attn: bool = False, mask_type: str = "diag", mask_random_seed: int = 42, sparse_attn_window: int = 500, global_window: int = 50, auto_sparsity: bool = False, sparsity: float = 0.95):
135
- super().__init__()
136
- assert dim % num_heads == 0
137
-
138
- hidden_dim = int(dim * hidden_scale)
139
-
140
- self.num_layers = num_layers
141
- self.classic_parity = 1 if cross_first else 0
142
- self.emb = emb
143
- self.max_period = max_period
144
- self.weight_decay = weight_decay
145
- self.weight_pos_embed = weight_pos_embed
146
- self.sin_random_shift = sin_random_shift
147
-
148
- if emb == "cape":
149
- self.cape_mean_normalize = cape_mean_normalize
150
- self.cape_augment = cape_augment
151
- self.cape_glob_loc_scale = cape_glob_loc_scale
152
-
153
- if emb == "scaled": self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)
154
-
155
- self.lr = lr
156
-
157
- activation: tp.Any = F.gelu if gelu else F.relu
158
-
159
- self.norm_in: nn.Module
160
- self.norm_in_t: nn.Module
161
-
162
- if norm_in:
163
- self.norm_in = nn.LayerNorm(dim)
164
- self.norm_in_t = nn.LayerNorm(dim)
165
- elif norm_in_group:
166
- self.norm_in = MyGroupNorm(int(norm_in_group), dim)
167
- self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
168
- else:
169
- self.norm_in = nn.Identity()
170
- self.norm_in_t = nn.Identity()
171
-
172
- self.layers = nn.ModuleList()
173
- self.layers_t = nn.ModuleList()
174
-
175
- kwargs_common = {
176
- "d_model": dim,
177
- "nhead": num_heads,
178
- "dim_feedforward": hidden_dim,
179
- "dropout": dropout,
180
- "activation": activation,
181
- "group_norm": group_norm,
182
- "norm_first": norm_first,
183
- "norm_out": norm_out,
184
- "layer_scale": layer_scale,
185
- "mask_type": mask_type,
186
- "mask_random_seed": mask_random_seed,
187
- "sparse_attn_window": sparse_attn_window,
188
- "global_window": global_window,
189
- "sparsity": sparsity,
190
- "auto_sparsity": auto_sparsity,
191
- "batch_first": True,
192
- }
193
-
194
- kwargs_classic_encoder = dict(kwargs_common)
195
- kwargs_classic_encoder.update({"sparse": sparse_self_attn})
196
- kwargs_cross_encoder = dict(kwargs_common)
197
- kwargs_cross_encoder.update({"sparse": sparse_cross_attn})
198
-
199
- for idx in range(num_layers):
200
- if idx % 2 == self.classic_parity:
201
- self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
202
- self.layers_t.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
203
- else:
204
- self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
205
- self.layers_t.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
206
-
207
- def forward(self, x, xt):
208
- B, C, Fr, T1 = x.shape
209
-
210
- pos_emb_2d = create_2d_sin_embedding(C, Fr, T1, x.device, self.max_period)
211
- pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")
212
-
213
- x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
214
- x = self.norm_in(x)
215
- x = x + self.weight_pos_embed * pos_emb_2d
216
-
217
- B, C, T2 = xt.shape
218
- xt = rearrange(xt, "b c t2 -> b t2 c")
219
-
220
- pos_emb = self._get_pos_embedding(T2, B, C, x.device)
221
- pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")
222
-
223
- xt = self.norm_in_t(xt)
224
- xt = xt + self.weight_pos_embed * pos_emb
225
-
226
- for idx in range(self.num_layers):
227
- if idx % 2 == self.classic_parity:
228
- x = self.layers[idx](x)
229
- xt = self.layers_t[idx](xt)
230
- else:
231
- old_x = x
232
- x = self.layers[idx](x, xt)
233
- xt = self.layers_t[idx](xt, old_x)
234
-
235
- x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
236
- xt = rearrange(xt, "b t2 c -> b c t2")
237
-
238
- return x, xt
239
-
240
- def _get_pos_embedding(self, T, B, C, device):
241
- if self.emb == "sin":
242
- shift = random.randrange(self.sin_random_shift + 1)
243
- pos_emb = create_sin_embedding(T, C, shift=shift, device=device, max_period=self.max_period)
244
- elif self.emb == "cape":
245
- if self.training: pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=self.cape_augment, max_global_shift=self.cape_glob_loc_scale[0], max_local_shift=self.cape_glob_loc_scale[1], max_scale=self.cape_glob_loc_scale[2])
246
- else: pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=False)
247
-
248
- elif self.emb == "scaled":
249
- pos = torch.arange(T, device=device)
250
- pos_emb = self.position_embeddings(pos)[:, None]
251
-
252
- return pos_emb
253
-
254
- def make_optim_group(self):
255
- group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
256
- if self.lr is not None: group["lr"] = self.lr
257
-
258
- return group
259
-
260
-
261
- class CrossTransformerEncoderLayer(nn.Module):
262
- def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation=F.relu, layer_norm_eps: float = 1e-5, layer_scale: bool = False, init_values: float = 1e-4, norm_first: bool = False, group_norm: bool = False, norm_out: bool = False, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, sparsity=0.95, auto_sparsity=None, device=None, dtype=None, batch_first=False):
263
- factory_kwargs = {"device": device, "dtype": dtype}
264
- super().__init__()
265
-
266
- self.auto_sparsity = auto_sparsity
267
-
268
- self.cross_attn: nn.Module
269
- self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)
270
-
271
- self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
272
- self.dropout = nn.Dropout(dropout)
273
- self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)
274
-
275
- self.norm_first = norm_first
276
-
277
- self.norm1: nn.Module
278
- self.norm2: nn.Module
279
- self.norm3: nn.Module
280
-
281
- if group_norm:
282
- self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
283
- self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
284
- self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
285
- else:
286
- self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
287
- self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
288
- self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
289
-
290
- self.norm_out = None
291
- if self.norm_first & norm_out:
292
- self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)
293
-
294
- self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
295
- self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
296
-
297
- self.dropout1 = nn.Dropout(dropout)
298
- self.dropout2 = nn.Dropout(dropout)
299
-
300
- if isinstance(activation, str): self.activation = self._get_activation_fn(activation)
301
- else: self.activation = activation
302
-
303
- def forward(self, q, k, mask=None):
304
- if self.norm_first:
305
- x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
306
- x = x + self.gamma_2(self._ff_block(self.norm3(x)))
307
-
308
- if self.norm_out: x = self.norm_out(x)
309
- else:
310
- x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
311
- x = self.norm2(x + self.gamma_2(self._ff_block(x)))
312
-
313
- return x
314
-
315
- def _ca_block(self, q, k, attn_mask=None):
316
- x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
317
- return self.dropout1(x)
318
-
319
- def _ff_block(self, x):
320
- x = self.linear2(self.dropout(self.activation(self.linear1(x))))
321
- return self.dropout2(x)
322
-
323
- def _get_activation_fn(self, activation):
324
- if activation == "relu": return F.relu
325
- elif activation == "gelu": return F.gelu
326
-
327
- raise RuntimeError(translations["activation"].format(activation=activation))
328
-
329
-
330
- class HTDemucs(nn.Module):
331
- @capture_init
332
- def __init__(self, sources, audio_channels=2, channels=48, channels_time=None, growth=2, nfft=4096, wiener_iters=0, end_iters=0, wiener_residual=False, cac=True, depth=4, rewrite=True, multi_freqs=None, multi_freqs_depth=3, freq_emb=0.2, emb_scale=10, emb_smooth=True, kernel_size=8, time_stride=2, stride=4, context=1, context_enc=0, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=8, dconv_init=1e-3, bottom_channels=0, t_layers=5, t_emb="sin", t_hidden_scale=4.0, t_heads=8, t_dropout=0.0, t_max_positions=10000, t_norm_in=True, t_norm_in_group=False, t_group_norm=False, t_norm_first=True, t_norm_out=True, t_max_period=10000.0, t_weight_decay=0.0, t_lr=None, t_layer_scale=True, t_gelu=True, t_weight_pos_embed=1.0, t_sin_random_shift=0, t_cape_mean_normalize=True, t_cape_augment=True, t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], t_sparse_self_attn=False, t_sparse_cross_attn=False, t_mask_type="diag", t_mask_random_seed=42, t_sparse_attn_window=500, t_global_window=100, t_sparsity=0.95, t_auto_sparsity=False, t_cross_first=False, rescale=0.1, samplerate=44100, segment=4 * 10, use_train_segment=True):
333
- super().__init__()
334
-
335
- self.cac = cac
336
- self.wiener_residual = wiener_residual
337
- self.audio_channels = audio_channels
338
- self.sources = sources
339
- self.kernel_size = kernel_size
340
- self.context = context
341
- self.stride = stride
342
- self.depth = depth
343
- self.bottom_channels = bottom_channels
344
- self.channels = channels
345
- self.samplerate = samplerate
346
- self.segment = segment
347
- self.use_train_segment = use_train_segment
348
- self.nfft = nfft
349
- self.hop_length = nfft // 4
350
- self.wiener_iters = wiener_iters
351
- self.end_iters = end_iters
352
- self.freq_emb = None
353
-
354
- assert wiener_iters == end_iters
355
-
356
- self.encoder = nn.ModuleList()
357
- self.decoder = nn.ModuleList()
358
- self.tencoder = nn.ModuleList()
359
- self.tdecoder = nn.ModuleList()
360
-
361
- chin = audio_channels
362
- chin_z = chin
363
-
364
- if self.cac: chin_z *= 2
365
-
366
- chout = channels_time or channels
367
- chout_z = channels
368
- freqs = nfft // 2
369
-
370
- for index in range(depth):
371
- norm = index >= norm_starts
372
- freq = freqs > 1
373
- stri = stride
374
- ker = kernel_size
375
-
376
- if not freq:
377
- assert freqs == 1
378
-
379
- ker = time_stride * 2
380
- stri = time_stride
381
-
382
- pad = True
383
- last_freq = False
384
-
385
- if freq and freqs <= kernel_size:
386
- ker = freqs
387
- pad = False
388
- last_freq = True
389
-
390
- kw = {
391
- "kernel_size": ker,
392
- "stride": stri,
393
- "freq": freq,
394
- "pad": pad,
395
- "norm": norm,
396
- "rewrite": rewrite,
397
- "norm_groups": norm_groups,
398
- "dconv_kw": {"depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
399
- }
400
-
401
- kwt = dict(kw)
402
- kwt["freq"] = 0
403
- kwt["kernel_size"] = kernel_size
404
- kwt["stride"] = stride
405
- kwt["pad"] = True
406
- kw_dec = dict(kw)
407
-
408
- multi = False
409
-
410
- if multi_freqs and index < multi_freqs_depth:
411
- multi = True
412
- kw_dec["context_freq"] = False
413
-
414
- if last_freq:
415
- chout_z = max(chout, chout_z)
416
- chout = chout_z
417
-
418
- enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)
419
-
420
- if freq:
421
- tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
422
- self.tencoder.append(tenc)
423
-
424
- if multi: enc = MultiWrap(enc, multi_freqs)
425
-
426
- self.encoder.append(enc)
427
-
428
- if index == 0:
429
- chin = self.audio_channels * len(self.sources)
430
- chin_z = chin
431
-
432
- if self.cac: chin_z *= 2
433
-
434
- dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)
435
-
436
- if multi:
437
- dec = MultiWrap(dec, multi_freqs)
438
-
439
- if freq:
440
- tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
441
- self.tdecoder.insert(0, tdec)
442
-
443
- self.decoder.insert(0, dec)
444
-
445
- chin = chout
446
- chin_z = chout_z
447
-
448
- chout = int(growth * chout)
449
- chout_z = int(growth * chout_z)
450
-
451
- if freq:
452
- if freqs <= kernel_size: freqs = 1
453
- else: freqs //= stride
454
-
455
- if index == 0 and freq_emb:
456
- self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
457
- self.freq_emb_scale = freq_emb
458
-
459
- if rescale: rescale_module(self, reference=rescale)
460
-
461
- transformer_channels = channels * growth ** (depth - 1)
462
-
463
- if bottom_channels:
464
- self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
465
- self.channel_downsampler = nn.Conv1d(bottom_channels, transformer_channels, 1)
466
- self.channel_upsampler_t = nn.Conv1d(transformer_channels, bottom_channels, 1)
467
- self.channel_downsampler_t = nn.Conv1d(bottom_channels, transformer_channels, 1)
468
-
469
- transformer_channels = bottom_channels
470
-
471
- if t_layers > 0: self.crosstransformer = CrossTransformerEncoder(dim=transformer_channels, emb=t_emb, hidden_scale=t_hidden_scale, num_heads=t_heads, num_layers=t_layers, cross_first=t_cross_first, dropout=t_dropout, max_positions=t_max_positions, norm_in=t_norm_in, norm_in_group=t_norm_in_group, group_norm=t_group_norm, norm_first=t_norm_first, norm_out=t_norm_out, max_period=t_max_period, weight_decay=t_weight_decay, lr=t_lr, layer_scale=t_layer_scale, gelu=t_gelu, sin_random_shift=t_sin_random_shift, weight_pos_embed=t_weight_pos_embed, cape_mean_normalize=t_cape_mean_normalize, cape_augment=t_cape_augment, cape_glob_loc_scale=t_cape_glob_loc_scale, sparse_self_attn=t_sparse_self_attn, sparse_cross_attn=t_sparse_cross_attn, mask_type=t_mask_type, mask_random_seed=t_mask_random_seed, sparse_attn_window=t_sparse_attn_window, global_window=t_global_window, sparsity=t_sparsity, auto_sparsity=t_auto_sparsity)
472
- else: self.crosstransformer = None
473
-
474
- def _spec(self, x):
475
- hl = self.hop_length
476
- nfft = self.nfft
477
-
478
- assert hl == nfft // 4
479
-
480
- le = int(math.ceil(x.shape[-1] / hl))
481
- pad = hl // 2 * 3
482
-
483
- x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")
484
-
485
- z = spectro(x, nfft, hl)[..., :-1, :]
486
-
487
- assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
488
-
489
- z = z[..., 2 : 2 + le]
490
-
491
- return z
492
-
493
- def _ispec(self, z, length=None, scale=0):
494
- hl = self.hop_length // (4**scale)
495
- z = F.pad(z, (0, 0, 0, 1))
496
- z = F.pad(z, (2, 2))
497
-
498
- pad = hl // 2 * 3
499
- le = hl * int(math.ceil(length / hl)) + 2 * pad
500
-
501
- x = ispectro(z, hl, length=le)
502
- x = x[..., pad : pad + length]
503
-
504
- return x
505
-
506
- def _magnitude(self, z):
507
- if self.cac:
508
- B, C, Fr, T = z.shape
509
- m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
510
- m = m.reshape(B, C * 2, Fr, T)
511
- else: m = z.abs()
512
-
513
- return m
514
-
515
- def _mask(self, z, m):
516
- niters = self.wiener_iters
517
- if self.cac:
518
- B, S, C, Fr, T = m.shape
519
- out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
520
- out = torch.view_as_complex(out.contiguous())
521
- return out
522
-
523
- if self.training: niters = self.end_iters
524
-
525
- if niters < 0:
526
- z = z[:, None]
527
- return z / (1e-8 + z.abs()) * m
528
- else: return self._wiener(m, z, niters)
529
-
530
- def _wiener(self, mag_out, mix_stft, niters):
531
- init = mix_stft.dtype
532
- wiener_win_len = 300
533
- residual = self.wiener_residual
534
-
535
- B, S, C, Fq, T = mag_out.shape
536
- mag_out = mag_out.permute(0, 4, 3, 2, 1)
537
- mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))
538
-
539
- outs = []
540
-
541
- for sample in range(B):
542
- pos = 0
543
- out = []
544
-
545
- for pos in range(0, T, wiener_win_len):
546
- frame = slice(pos, pos + wiener_win_len)
547
- z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
548
- out.append(z_out.transpose(-1, -2))
549
-
550
- outs.append(torch.cat(out, dim=0))
551
-
552
- out = torch.view_as_complex(torch.stack(outs, 0))
553
- out = out.permute(0, 4, 3, 2, 1).contiguous()
554
-
555
- if residual: out = out[:, :-1]
556
-
557
- assert list(out.shape) == [B, S, C, Fq, T]
558
-
559
- return out.to(init)
560
-
561
- def valid_length(self, length: int):
562
- if not self.use_train_segment: return length
563
-
564
- training_length = int(self.segment * self.samplerate)
565
- if training_length < length: raise ValueError(translations["length_or_training_length"].format(length=length, training_length=training_length))
566
-
567
- return training_length
568
-
569
- def forward(self, mix):
570
- length = mix.shape[-1]
571
- length_pre_pad = None
572
-
573
- if self.use_train_segment:
574
- if self.training: self.segment = Fraction(mix.shape[-1], self.samplerate)
575
- else:
576
- training_length = int(self.segment * self.samplerate)
577
-
578
- if mix.shape[-1] < training_length:
579
- length_pre_pad = mix.shape[-1]
580
- mix = F.pad(mix, (0, training_length - length_pre_pad))
581
-
582
- z = self._spec(mix)
583
- mag = self._magnitude(z).to(mix.device)
584
- x = mag
585
-
586
- B, C, Fq, T = x.shape
587
-
588
- mean = x.mean(dim=(1, 2, 3), keepdim=True)
589
- std = x.std(dim=(1, 2, 3), keepdim=True)
590
- x = (x - mean) / (1e-5 + std)
591
-
592
- xt = mix
593
- meant = xt.mean(dim=(1, 2), keepdim=True)
594
- stdt = xt.std(dim=(1, 2), keepdim=True)
595
- xt = (xt - meant) / (1e-5 + stdt)
596
-
597
- saved = []
598
- saved_t = []
599
- lengths = []
600
- lengths_t = []
601
-
602
- for idx, encode in enumerate(self.encoder):
603
- lengths.append(x.shape[-1])
604
- inject = None
605
-
606
- if idx < len(self.tencoder):
607
- lengths_t.append(xt.shape[-1])
608
- tenc = self.tencoder[idx]
609
- xt = tenc(xt)
610
-
611
- if not tenc.empty: saved_t.append(xt)
612
- else: inject = xt
613
-
614
- x = encode(x, inject)
615
-
616
- if idx == 0 and self.freq_emb is not None:
617
- frs = torch.arange(x.shape[-2], device=x.device)
618
- emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
619
- x = x + self.freq_emb_scale * emb
620
-
621
- saved.append(x)
622
-
623
-
624
- if self.crosstransformer:
625
- if self.bottom_channels:
626
- b, c, f, t = x.shape
627
- x = rearrange(x, "b c f t-> b c (f t)")
628
- x = self.channel_upsampler(x)
629
- x = rearrange(x, "b c (f t)-> b c f t", f=f)
630
-
631
- xt = self.channel_upsampler_t(xt)
632
-
633
- x, xt = self.crosstransformer(x, xt)
634
-
635
- if self.bottom_channels:
636
- x = rearrange(x, "b c f t-> b c (f t)")
637
- x = self.channel_downsampler(x)
638
- x = rearrange(x, "b c (f t)-> b c f t", f=f)
639
-
640
- xt = self.channel_downsampler_t(xt)
641
-
642
- for idx, decode in enumerate(self.decoder):
643
- skip = saved.pop(-1)
644
- x, pre = decode(x, skip, lengths.pop(-1))
645
-
646
- offset = self.depth - len(self.tdecoder)
647
-
648
- if idx >= offset:
649
- tdec = self.tdecoder[idx - offset]
650
- length_t = lengths_t.pop(-1)
651
-
652
- if tdec.empty:
653
- assert pre.shape[2] == 1, pre.shape
654
- pre = pre[:, :, 0]
655
- xt, _ = tdec(pre, None, length_t)
656
- else:
657
- skip = saved_t.pop(-1)
658
- xt, _ = tdec(xt, skip, length_t)
659
-
660
- assert len(saved) == 0
661
- assert len(lengths_t) == 0
662
- assert len(saved_t) == 0
663
-
664
-
665
- S = len(self.sources)
666
- x = x.view(B, S, -1, Fq, T)
667
- x = x * std[:, None] + mean[:, None]
668
-
669
- device_type = x.device.type
670
- device_load = f"{device_type}:{x.device.index}" if not device_type == "mps" else device_type
671
- x_is_other_gpu = not device_type in ["cuda", "cpu"]
672
-
673
- if x_is_other_gpu: x = x.cpu()
674
-
675
- zout = self._mask(z, x)
676
-
677
- if self.use_train_segment: x = self._ispec(zout, length) if self.training else self._ispec(zout, training_length)
678
- else: x = self._ispec(zout, length)
679
-
680
- if x_is_other_gpu: x = x.to(device_load)
681
-
682
- if self.use_train_segment: xt = xt.view(B, S, -1, length) if self.training else xt.view(B, S, -1, training_length)
683
- else: xt = xt.view(B, S, -1, length)
684
-
685
- xt = xt * stdt[:, None] + meant[:, None]
686
- x = xt + x
687
-
688
- if length_pre_pad: x = x[..., :length_pre_pad]
689
-
690
- return x
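
When bottom_channels is set, the frequency branch is flattened with einops so a 1x1 Conv1d can change its channel count before the cross-transformer, then restored afterwards. A minimal sketch of that reshape round trip (sizes are illustrative):

    import torch
    from einops import rearrange

    b, c, f, t = 1, 48, 8, 16
    x = torch.randn(b, c, f, t)
    flat = rearrange(x, "b c f t -> b c (f t)")           # what feeds channel_upsampler
    back = rearrange(flat, "b c (f t) -> b c f t", f=f)   # shape restored after channel_downsampler
    assert torch.equal(back, x)
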
main/library/uvr5_separator/demucs/states.py DELETED
@@ -1,70 +0,0 @@
1
- import os
2
- import sys
3
- import torch
4
- import inspect
5
- import warnings
6
- import functools
7
-
8
- from pathlib import Path
9
- from diffq import restore_quantized_state
10
-
11
- now_dir = os.getcwd()
12
- sys.path.append(now_dir)
13
-
14
- from main.configs.config import Config
15
-
16
- translations = Config().translations
17
-
18
-
19
- def load_model(path_or_package, strict=False):
20
- if isinstance(path_or_package, dict): package = path_or_package
21
- elif isinstance(path_or_package, (str, Path)):
22
- with warnings.catch_warnings():
23
- warnings.simplefilter("ignore")
24
-
25
- path = path_or_package
26
- package = torch.load(path, map_location="cpu")
27
- else: raise ValueError(f"{translations['type_not_valid']} {path_or_package}.")
28
-
29
-
30
- klass = package["klass"]
31
- args = package["args"]
32
- kwargs = package["kwargs"]
33
-
34
-
35
- if strict: model = klass(*args, **kwargs)
36
- else:
37
- sig = inspect.signature(klass)
38
-
39
- for key in list(kwargs):
40
- if key not in sig.parameters:
41
- warnings.warn(translations["del_parameter"] + key)
42
-
43
- del kwargs[key]
44
-
45
- model = klass(*args, **kwargs)
46
-
47
- state = package["state"]
48
-
49
- set_state(model, state)
50
-
51
- return model
52
-
53
-
54
- def set_state(model, state, quantizer=None):
55
- if state.get("__quantized"):
56
- if quantizer is not None: quantizer.restore_quantized_state(model, state["quantized"])
57
- else: restore_quantized_state(model, state)
58
- else: model.load_state_dict(state)
59
-
60
- return state
61
-
62
-
63
- def capture_init(init):
64
- @functools.wraps(init)
65
-
66
- def __init__(self, *args, **kwargs):
67
- self._init_args_kwargs = (args, kwargs)
68
- init(self, *args, **kwargs)
69
-
70
- return __init__
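
capture_init records the constructor arguments so that a checkpoint can later rebuild the model via klass(*args, **kwargs), as load_model above does. A self-contained sketch of the pattern (the Demo class is hypothetical, for illustration only):

    import functools

    def capture_init(init):
        @functools.wraps(init)
        def __init__(self, *args, **kwargs):
            self._init_args_kwargs = (args, kwargs)  # remember how the instance was built
            init(self, *args, **kwargs)
        return __init__

    class Demo:
        @capture_init
        def __init__(self, channels, depth=4):
            self.channels, self.depth = channels, depth

    m = Demo(48, depth=6)
    assert m._init_args_kwargs == ((48,), {"depth": 6})
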
main/library/uvr5_separator/demucs/utils.py DELETED
@@ -1,10 +0,0 @@
1
- import torch
2
- import typing as tp
3
-
4
- def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
5
- ref_size: int
6
- ref_size = reference.size(-1) if isinstance(reference, torch.Tensor) else reference
7
- delta = tensor.size(-1) - ref_size
8
- if delta < 0: raise ValueError(f"tensor is shorter than reference by {-delta} samples.")
9
- if delta: tensor = tensor[..., delta // 2 : -(delta - delta // 2)]
10
- return tensor
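
center_trim drops an equal number of samples from both ends so the output matches the reference length, removing the extra sample from the right when the difference is odd. A simplified sketch (integer reference only):

    import torch

    def center_trim(tensor, ref_size):
        delta = tensor.size(-1) - ref_size
        if delta < 0:
            raise ValueError(f"tensor is shorter than reference by {-delta} samples.")
        return tensor[..., delta // 2 : tensor.size(-1) - (delta - delta // 2)] if delta else tensor

    x = torch.arange(10.0)
    assert torch.equal(center_trim(x, 6), torch.arange(2.0, 8.0))  # drop 2 samples per side
    assert torch.equal(center_trim(x, 5), torch.arange(2.0, 7.0))  # odd delta: 2 left, 3 right
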
main/library/uvr5_separator/spec_utils.py DELETED
@@ -1,1100 +0,0 @@
1
- import io
2
- import os
3
- import six
4
- import sys
5
- import math
6
- import librosa
7
- import tempfile
8
- import platform
9
- import traceback
10
- import audioread
11
- import subprocess
12
-
13
- import numpy as np
14
- import soundfile as sf
15
-
16
- from scipy.signal import correlate, hilbert
17
-
18
- now_dir = os.getcwd()
19
- sys.path.append(now_dir)
20
-
21
- from main.configs.config import Config
22
- translations = Config().translations
23
-
24
- OPERATING_SYSTEM = platform.system()
25
- SYSTEM_ARCH = platform.platform()
26
- SYSTEM_PROC = platform.processor()
27
- ARM = "arm"
28
- AUTO_PHASE = "Automatic"
29
- POSITIVE_PHASE = "Positive Phase"
30
- NEGATIVE_PHASE = "Negative Phase"
31
- NONE_P = ("None",)
32
- LOW_P = ("Shifts: Low",)
33
- MED_P = ("Shifts: Medium",)
34
- HIGH_P = ("Shifts: High",)
35
- VHIGH_P = "Shifts: Very High"
36
- MAXIMUM_P = "Shifts: Maximum"
37
- BASE_PATH_RUB = sys._MEIPASS if getattr(sys, 'frozen', False) else os.path.dirname(os.path.abspath(__file__))
38
- DEVNULL = open(os.devnull, 'w') if six.PY2 else subprocess.DEVNULL
39
- MAX_SPEC = "Max Spec"
40
- MIN_SPEC = "Min Spec"
41
- LIN_ENSE = "Linear Ensemble"
42
- MAX_WAV = MAX_SPEC
43
- MIN_WAV = MIN_SPEC
44
- AVERAGE = "Average"
45
-
46
- progress_value = 0
47
- last_update_time = 0
48
- is_macos = False
49
-
50
- if OPERATING_SYSTEM == "Darwin":
51
- wav_resolution = "polyphase" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else "sinc_fastest"
52
- wav_resolution_float_resampling = "kaiser_best" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else wav_resolution
53
- is_macos = True
54
- else:
55
- wav_resolution = "sinc_fastest"
56
- wav_resolution_float_resampling = wav_resolution
57
-
58
- def crop_center(h1, h2):
59
- h1_shape = h1.size()
60
- h2_shape = h2.size()
61
-
62
- if h1_shape[3] == h2_shape[3]: return h1
63
- elif h1_shape[3] < h2_shape[3]: raise ValueError("h1_shape[3] must be >= h2_shape[3]")
64
-
65
- s_time = (h1_shape[3] - h2_shape[3]) // 2
66
- e_time = s_time + h2_shape[3]
67
-
68
- h1 = h1[:, :, :, s_time:e_time]
69
-
70
- return h1
71
-
72
- def preprocess(X_spec):
73
- return np.abs(X_spec), np.angle(X_spec)
74
-
75
- def make_padding(width, cropsize, offset):
76
- left = offset
77
- roi_size = cropsize - offset * 2
78
-
79
- if roi_size == 0: roi_size = cropsize
80
-
81
- right = roi_size - (width % roi_size) + left
82
-
83
- return left, right, roi_size
84
-
85
- def normalize(wave, max_peak=1.0):
86
- maxv = np.abs(wave).max()
87
-
88
- if maxv > max_peak: wave *= max_peak / maxv
89
-
90
- return wave
91
-
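
normalize only attenuates, never amplifies: the waveform is scaled down when its peak exceeds max_peak and left untouched otherwise. For example:

    import numpy as np

    wave = np.array([0.5, -2.0, 1.0])
    maxv = np.abs(wave).max()
    out = wave * (1.0 / maxv) if maxv > 1.0 else wave  # same effect as normalize(wave) with max_peak=1.0
    assert np.isclose(np.abs(out).max(), 1.0)
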
92
- def auto_transpose(audio_array: np.ndarray):
93
- if audio_array.shape[1] == 2: return audio_array.T
94
-
95
- return audio_array
96
-
97
- def write_array_to_mem(audio_data, subtype):
98
- if isinstance(audio_data, np.ndarray):
99
- audio_buffer = io.BytesIO()
100
- sf.write(audio_buffer, audio_data, 44100, subtype=subtype, format="WAV")
101
-
102
- audio_buffer.seek(0)
103
-
104
- return audio_buffer
105
- else: return audio_data
106
-
107
- def spectrogram_to_image(spec, mode="magnitude"):
108
- if mode == "magnitude":
109
- y = np.abs(spec) if np.iscomplexobj(spec) else spec
110
- y = np.log10(y**2 + 1e-8)
111
- elif mode == "phase":
112
- y = np.angle(spec) if np.iscomplexobj(spec) else spec
113
-
114
- y -= y.min()
115
- y *= 255 / y.max()
116
- img = np.uint8(y)
117
-
118
- if y.ndim == 3:
119
- img = img.transpose(1, 2, 0)
120
- img = np.concatenate([np.max(img, axis=2, keepdims=True), img], axis=2)
121
-
122
- return img
123
-
124
- def reduce_vocal_aggressively(X, y, softmask):
125
- v = X - y
126
-
127
- y_mag_tmp = np.abs(y)
128
- v_mag_tmp = np.abs(v)
129
-
130
- v_mask = v_mag_tmp > y_mag_tmp
131
- y_mag = np.clip(y_mag_tmp - v_mag_tmp * v_mask * softmask, 0, np.inf)
132
-
133
- return y_mag * np.exp(1.0j * np.angle(y))
134
-
135
- def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32):
136
- mask = y_mask
137
-
138
- try:
139
- if min_range < fade_size * 2: raise ValueError("min_range >= fade_size * 2")
140
-
141
- idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
142
- start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
143
- end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
144
- artifact_idx = np.where(end_idx - start_idx > min_range)[0]
145
- weight = np.zeros_like(y_mask)
146
-
147
- if len(artifact_idx) > 0:
148
- start_idx = start_idx[artifact_idx]
149
- end_idx = end_idx[artifact_idx]
150
- old_e = None
151
-
152
- for s, e in zip(start_idx, end_idx):
153
- if old_e is not None and s - old_e < fade_size: s = old_e - fade_size * 2
154
-
155
- if s != 0: weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size)
156
- else: s -= fade_size
157
-
158
- if e != y_mask.shape[2]: weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size)
159
- else: e += fade_size
160
-
161
- weight[:, :, s + fade_size : e - fade_size] = 1
162
- old_e = e
163
-
164
- v_mask = 1 - y_mask
165
- y_mask += weight * v_mask
166
- mask = y_mask
167
- except Exception as e:
168
- error_name = f"{type(e).__name__}"
169
- traceback_text = "".join(traceback.format_tb(e.__traceback__))
170
- message = f'{error_name}: "{e}"\n{traceback_text}'
171
- print(translations["not_success"], message)
172
-
173
- return mask
174
-
175
- def align_wave_head_and_tail(a, b):
176
- l = min([a[0].size, b[0].size])
177
-
178
- return a[:l, :l], b[:l, :l]
179
-
180
- def convert_channels(spec, mp, band):
181
- cc = mp.param["band"][band].get("convert_channels")
182
-
183
- if "mid_side_c" == cc:
184
- spec_left = np.add(spec[0], spec[1] * 0.25)
185
- spec_right = np.subtract(spec[1], spec[0] * 0.25)
186
- elif "mid_side" == cc:
187
- spec_left = np.add(spec[0], spec[1]) / 2
188
- spec_right = np.subtract(spec[0], spec[1])
189
- elif "stereo_n" == cc:
190
- spec_left = np.add(spec[0], spec[1] * 0.25) / 0.9375
191
- spec_right = np.add(spec[1], spec[0] * 0.25) / 0.9375
192
- else: return spec
193
-
194
- return np.asfortranarray([spec_left, spec_right])
195
-
196
- def combine_spectrograms(specs, mp, is_v51_model=False):
197
- l = min([specs[i].shape[2] for i in specs])
198
- spec_c = np.zeros(shape=(2, mp.param["bins"] + 1, l), dtype=np.complex64)
199
- offset = 0
200
- bands_n = len(mp.param["band"])
201
-
202
- for d in range(1, bands_n + 1):
203
- h = mp.param["band"][d]["crop_stop"] - mp.param["band"][d]["crop_start"]
204
- spec_c[:, offset : offset + h, :l] = specs[d][:, mp.param["band"][d]["crop_start"] : mp.param["band"][d]["crop_stop"], :l]
205
- offset += h
206
-
207
- if offset > mp.param["bins"]: raise ValueError("Too many bins")
208
-
209
- if mp.param["pre_filter_start"] > 0:
210
- if is_v51_model: spec_c *= get_lp_filter_mask(spec_c.shape[1], mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
211
- else:
212
- if bands_n == 1: spec_c = fft_lp_filter(spec_c, mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
213
- else:
214
- gp = 1
215
-
216
- for b in range(mp.param["pre_filter_start"] + 1, mp.param["pre_filter_stop"]):
217
- g = math.pow(10, -(b - mp.param["pre_filter_start"]) * (3.5 - gp) / 20.0)
218
- gp = g
219
- spec_c[:, b, :] *= g
220
-
221
- return np.asfortranarray(spec_c)
222
-
223
- def wave_to_spectrogram(wave, hop_length, n_fft, mp, band, is_v51_model=False):
224
- if wave.ndim == 1: wave = np.asfortranarray([wave, wave])
225
-
226
- if not is_v51_model:
227
- if mp.param["reverse"]:
228
- wave_left = np.flip(np.asfortranarray(wave[0]))
229
- wave_right = np.flip(np.asfortranarray(wave[1]))
230
- elif mp.param["mid_side"]:
231
- wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
232
- wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
233
- elif mp.param["mid_side_b2"]:
234
- wave_left = np.asfortranarray(np.add(wave[1], wave[0] * 0.5))
235
- wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * 0.5))
236
- else:
237
- wave_left = np.asfortranarray(wave[0])
238
- wave_right = np.asfortranarray(wave[1])
239
- else:
240
- wave_left = np.asfortranarray(wave[0])
241
- wave_right = np.asfortranarray(wave[1])
242
-
243
- spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
244
- spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)
245
-
246
- spec = np.asfortranarray([spec_left, spec_right])
247
-
248
- if is_v51_model: spec = convert_channels(spec, mp, band)
249
-
250
- return spec
251
-
252
- def spectrogram_to_wave(spec, hop_length=1024, mp={}, band=0, is_v51_model=True):
253
- spec_left = np.asfortranarray(spec[0])
254
- spec_right = np.asfortranarray(spec[1])
255
-
256
- wave_left = librosa.istft(spec_left, hop_length=hop_length)
257
- wave_right = librosa.istft(spec_right, hop_length=hop_length)
258
-
259
- if is_v51_model:
260
- cc = mp.param["band"][band].get("convert_channels")
261
-
262
- if "mid_side_c" == cc: return np.asfortranarray([np.subtract(wave_left / 1.0625, wave_right / 4.25), np.add(wave_right / 1.0625, wave_left / 4.25)])
263
- elif "mid_side" == cc: return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
264
- elif "stereo_n" == cc: return np.asfortranarray([np.subtract(wave_left, wave_right * 0.25), np.subtract(wave_right, wave_left * 0.25)])
265
- else:
266
- if mp.param["reverse"]: return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
267
- elif mp.param["mid_side"]: return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
268
- elif mp.param["mid_side_b2"]: return np.asfortranarray([np.add(wave_right / 1.25, 0.4 * wave_left), np.subtract(wave_left / 1.25, 0.4 * wave_right)])
269
-
270
- return np.asfortranarray([wave_left, wave_right])
271
-
272
- def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None, is_v51_model=False):
273
- bands_n = len(mp.param["band"])
274
- offset = 0
275
-
276
- for d in range(1, bands_n + 1):
277
- bp = mp.param["band"][d]
278
- spec_s = np.zeros(shape=(2, bp["n_fft"] // 2 + 1, spec_m.shape[2]), dtype=complex)
279
- h = bp["crop_stop"] - bp["crop_start"]
280
- spec_s[:, bp["crop_start"] : bp["crop_stop"], :] = spec_m[:, offset : offset + h, :]
281
-
282
- offset += h
283
-
284
- if d == bands_n:
285
- if extra_bins_h:
286
- max_bin = bp["n_fft"] // 2
287
- spec_s[:, max_bin - extra_bins_h : max_bin, :] = extra_bins[:, :extra_bins_h, :]
288
-
289
- if bp["hpf_start"] > 0:
290
- if is_v51_model: spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
291
- else: spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
292
-
293
- if bands_n == 1: wave = spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model)
294
- else: wave = np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))
295
- else:
296
- sr = mp.param["band"][d + 1]["sr"]
297
- if d == 1:
298
- if is_v51_model: spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
299
- else: spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])
300
-
301
- try:
302
- wave = librosa.resample(spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model), orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
303
- except ValueError as e:
304
- print(f"{translations['resample_error']}: {e}")
305
- print(f"{translations['shapes']} Spec_s: {spec_s.shape}, SR: {sr}, {translations['wav_resolution']}: {wav_resolution}")
306
- else:
307
- if is_v51_model:
308
- spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
309
- spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
310
- else:
311
- spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
312
- spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])
313
-
314
- wave2 = np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))
315
-
316
- try:
317
- wave = librosa.resample(wave2, orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
318
- except ValueError as e:
319
- print(f"{translations['resample_error']}: {e}")
320
- print(f"{translations['shapes']} Spec_s: {spec_s.shape}, SR: {sr}, {translations['wav_resolution']}: {wav_resolution}")
321
-
322
- return wave
323
-
324
- def get_lp_filter_mask(n_bins, bin_start, bin_stop):
325
- return np.concatenate([np.ones((bin_start - 1, 1)), np.linspace(1, 0, bin_stop - bin_start + 1)[:, None], np.zeros((n_bins - bin_stop, 1))], axis=0)
326
-
327
- def get_hp_filter_mask(n_bins, bin_start, bin_stop):
328
- return np.concatenate([np.zeros((bin_stop + 1, 1)), np.linspace(0, 1, 1 + bin_start - bin_stop)[:, None], np.ones((n_bins - bin_start - 2, 1))], axis=0)
329
-
330
- def fft_lp_filter(spec, bin_start, bin_stop):
331
- g = 1.0
332
-
333
- for b in range(bin_start, bin_stop):
334
- g -= 1 / (bin_stop - bin_start)
335
- spec[:, b, :] = g * spec[:, b, :]
336
-
337
- spec[:, bin_stop:, :] *= 0
338
-
339
- return spec
340
-
341
- def fft_hp_filter(spec, bin_start, bin_stop):
342
- g = 1.0
343
-
344
- for b in range(bin_start, bin_stop, -1):
345
- g -= 1 / (bin_start - bin_stop)
346
- spec[:, b, :] = g * spec[:, b, :]
347
-
348
- spec[:, 0 : bin_stop + 1, :] *= 0
349
-
350
- return spec
351
-
352
- def spectrogram_to_wave_old(spec, hop_length=1024):
353
- if spec.ndim == 2: wave = librosa.istft(spec, hop_length=hop_length)
354
- elif spec.ndim == 3:
355
- spec_left = np.asfortranarray(spec[0])
356
- spec_right = np.asfortranarray(spec[1])
357
-
358
- wave_left = librosa.istft(spec_left, hop_length=hop_length)
359
- wave_right = librosa.istft(spec_right, hop_length=hop_length)
360
- wave = np.asfortranarray([wave_left, wave_right])
361
-
362
- return wave
363
-
364
- def wave_to_spectrogram_old(wave, hop_length, n_fft):
365
- wave_left = np.asfortranarray(wave[0])
366
- wave_right = np.asfortranarray(wave[1])
367
- spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
368
- spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)
369
-
370
- return np.asfortranarray([spec_left, spec_right])
371
-
372
- def mirroring(a, spec_m, input_high_end, mp):
373
- if "mirroring" == a:
374
- mirror = np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1)
375
- mirror = mirror * np.exp(1.0j * np.angle(input_high_end))
376
-
377
- return np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)
378
-
379
- if "mirroring2" == a:
380
- mirror = np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1)
381
- mi = np.multiply(mirror, input_high_end * 1.7)
382
-
383
- return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi)
384
-
385
- def adjust_aggr(mask, is_non_accom_stem, aggressiveness):
386
- aggr = aggressiveness["value"] * 2
387
-
388
- if aggr != 0:
389
- if is_non_accom_stem:
390
- aggr = 1 - aggr
391
-
392
- if np.any(aggr > 10) or np.any(aggr < -10): print(f"{translations['warnings']}: {aggr}")
393
-
394
- aggr = [aggr, aggr]
395
-
396
- if aggressiveness["aggr_correction"] is not None:
397
- aggr[0] += aggressiveness["aggr_correction"]["left"]
398
- aggr[1] += aggressiveness["aggr_correction"]["right"]
399
-
400
- for ch in range(2):
401
- mask[ch, : aggressiveness["split_bin"]] = np.power(mask[ch, : aggressiveness["split_bin"]], 1 + aggr[ch] / 3)
402
- mask[ch, aggressiveness["split_bin"] :] = np.power(mask[ch, aggressiveness["split_bin"] :], 1 + aggr[ch])
403
-
404
- return mask
405
-
406
- def stft(wave, nfft, hl):
407
- wave_left = np.asfortranarray(wave[0])
408
- wave_right = np.asfortranarray(wave[1])
409
- spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
410
- spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
411
- spec = np.asfortranarray([spec_left, spec_right])
412
-
413
- return spec
414
-
415
- def istft(spec, hl):
416
- spec_left = np.asfortranarray(spec[0])
417
- spec_right = np.asfortranarray(spec[1])
418
- wave_left = librosa.istft(spec_left, hop_length=hl)
419
- wave_right = librosa.istft(spec_right, hop_length=hl)
420
- wave = np.asfortranarray([wave_left, wave_right])
421
-
422
- return wave
423
-
424
- def spec_effects(wave, algorithm="Default", value=None):
425
- if np.isnan(wave).any() or np.isinf(wave).any(): print(f"{translations['warnings_2']}: {wave.shape}")
426
-
427
- spec = [stft(wave[0], 2048, 1024), stft(wave[1], 2048, 1024)]
428
-
429
- if algorithm == "Min_Mag":
430
- v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0])
431
- wave = istft(v_spec_m, 1024)
432
- elif algorithm == "Max_Mag":
433
- v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0])
434
- wave = istft(v_spec_m, 1024)
435
- elif algorithm == "Default": wave = (wave[1] * value) + (wave[0] * (1 - value))
436
- elif algorithm == "Invert_p":
437
- X_mag = np.abs(spec[0])
438
- y_mag = np.abs(spec[1])
439
-
440
- max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
441
- v_spec = spec[1] - max_mag * np.exp(1.0j * np.angle(spec[0]))
442
-
443
- wave = istft(v_spec, 1024)
444
-
445
- return wave
446
-
447
- def spectrogram_to_wave_no_mp(spec, n_fft=2048, hop_length=1024):
448
- wave = librosa.istft(spec, n_fft=n_fft, hop_length=hop_length)
449
- if wave.ndim == 1: wave = np.asfortranarray([wave, wave])
450
-
451
- return wave
452
-
453
- def wave_to_spectrogram_no_mp(wave):
454
-
455
- spec = librosa.stft(wave, n_fft=2048, hop_length=1024)
456
-
457
- if spec.ndim == 1: spec = np.asfortranarray([spec, spec])
458
-
459
- return spec
460
-
461
- def invert_audio(specs, invert_p=True):
462
-
463
- ln = min([specs[0].shape[2], specs[1].shape[2]])
464
- specs[0] = specs[0][:, :, :ln]
465
- specs[1] = specs[1][:, :, :ln]
466
-
467
- if invert_p:
468
- X_mag = np.abs(specs[0])
469
- y_mag = np.abs(specs[1])
470
- max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
471
- v_spec = specs[1] - max_mag * np.exp(1.0j * np.angle(specs[0]))
472
- else:
473
- specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2)
474
- v_spec = specs[0] - specs[1]
475
-
476
- return v_spec
477
-
478
- def invert_stem(mixture, stem):
479
- mixture = wave_to_spectrogram_no_mp(mixture)
480
- stem = wave_to_spectrogram_no_mp(stem)
481
- output = spectrogram_to_wave_no_mp(invert_audio([mixture, stem]))
482
-
483
- return -output.T
484
-
485
- def ensembling(a, inputs, is_wavs=False):
486
- for i in range(1, len(inputs)):
487
- if i == 1: input = inputs[0]
488
-
489
- if is_wavs:
490
- ln = min([input.shape[1], inputs[i].shape[1]])
491
- input = input[:, :ln]
492
- inputs[i] = inputs[i][:, :ln]
493
- else:
494
- ln = min([input.shape[2], inputs[i].shape[2]])
495
- input = input[:, :, :ln]
496
- inputs[i] = inputs[i][:, :, :ln]
497
-
498
- if MIN_SPEC == a: input = np.where(np.abs(inputs[i]) <= np.abs(input), inputs[i], input)
499
- if MAX_SPEC == a: input = np.where(np.abs(inputs[i]) >= np.abs(input), inputs[i], input)
500
-
501
- return input
502
-
503
- def ensemble_for_align(waves):
504
-
505
- specs = []
506
-
507
- for wav in waves:
508
- spec = wave_to_spectrogram_no_mp(wav.T)
509
- specs.append(spec)
510
-
511
- wav_aligned = spectrogram_to_wave_no_mp(ensembling(MIN_SPEC, specs)).T
512
- wav_aligned = match_array_shapes(wav_aligned, waves[1], is_swap=True)
513
-
514
- return wav_aligned
515
-
516
- def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path, is_wave=False, is_array=False):
517
- wavs_ = []
518
-
519
- if algorithm == AVERAGE:
520
- output = average_audio(audio_input)
521
- samplerate = 44100
522
- else:
523
- specs = []
524
-
525
- for i in range(len(audio_input)):
526
- wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100)
527
- wavs_.append(wave)
528
- spec = wave if is_wave else wave_to_spectrogram_no_mp(wave)
529
- specs.append(spec)
530
-
531
- wave_shapes = [w.shape[1] for w in wavs_]
532
- target_shape = wavs_[wave_shapes.index(max(wave_shapes))]
533
-
534
- output = ensembling(algorithm, specs, is_wavs=True) if is_wave else spectrogram_to_wave_no_mp(ensembling(algorithm, specs))
535
- output = to_shape(output, target_shape.shape)
536
-
537
- sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set)
538
-
539
- def to_shape(x, target_shape):
540
- padding_list = []
541
-
542
- for x_dim, target_dim in zip(x.shape, target_shape):
543
- pad_value = target_dim - x_dim
544
- pad_tuple = (0, pad_value)
545
- padding_list.append(pad_tuple)
546
-
547
- return np.pad(x, tuple(padding_list), mode="constant")
548
-
549
- def to_shape_minimize(x: np.ndarray, target_shape):
550
- padding_list = []
551
-
552
- for x_dim, target_dim in zip(x.shape, target_shape):
553
- pad_value = target_dim - x_dim
554
- pad_tuple = (0, pad_value)
555
- padding_list.append(pad_tuple)
556
-
557
- return np.pad(x, tuple(padding_list), mode="constant")
558
-
559
- def detect_leading_silence(audio, sr, silence_threshold=0.007, frame_length=1024):
560
- if len(audio.shape) == 2:
561
- channel = np.argmax(np.sum(np.abs(audio), axis=1))
562
- audio = audio[channel]
563
-
564
- for i in range(0, len(audio), frame_length):
565
- if np.max(np.abs(audio[i : i + frame_length])) > silence_threshold: return (i / sr) * 1000
566
-
567
- return (len(audio) / sr) * 1000
568
-
569
- def adjust_leading_silence(target_audio, reference_audio, silence_threshold=0.01, frame_length=1024):
570
- def find_silence_end(audio):
571
- if len(audio.shape) == 2:
572
- channel = np.argmax(np.sum(np.abs(audio), axis=1))
573
- audio_mono = audio[channel]
574
- else: audio_mono = audio
575
-
576
- for i in range(0, len(audio_mono), frame_length):
577
- if np.max(np.abs(audio_mono[i : i + frame_length])) > silence_threshold: return i
578
-
579
- return len(audio_mono)
580
-
581
- ref_silence_end = find_silence_end(reference_audio)
582
- target_silence_end = find_silence_end(target_audio)
583
- silence_difference = ref_silence_end - target_silence_end
584
-
585
- try:
586
- ref_silence_end_p = (ref_silence_end / 44100) * 1000
587
- target_silence_end_p = (target_silence_end / 44100) * 1000
588
- silence_difference_p = ref_silence_end_p - target_silence_end_p
589
-
590
- print("im lặng khác biệt: ", silence_difference_p)
591
- except Exception as e:
592
- pass
593
-
594
- if silence_difference > 0:
595
- silence_to_add = np.zeros((target_audio.shape[0], silence_difference)) if len(target_audio.shape) == 2 else np.zeros(silence_difference)
596
-
597
- return np.hstack((silence_to_add, target_audio))
598
- elif silence_difference < 0:
599
- if len(target_audio.shape) == 2: return target_audio[:, -silence_difference:]
600
- else: return target_audio[-silence_difference:]
601
- else: return target_audio
602
-
603
- def match_array_shapes(array_1: np.ndarray, array_2: np.ndarray, is_swap=False):
604
-
605
- if is_swap: array_1, array_2 = array_1.T, array_2.T
606
-
607
- if array_1.shape[1] > array_2.shape[1]: array_1 = array_1[:, : array_2.shape[1]]
608
- elif array_1.shape[1] < array_2.shape[1]:
609
- padding = array_2.shape[1] - array_1.shape[1]
610
- array_1 = np.pad(array_1, ((0, 0), (0, padding)), "constant", constant_values=0)
611
-
612
- if is_swap: array_1, array_2 = array_1.T, array_2.T
613
-
614
- return array_1
615
-
616
- def match_mono_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
617
- if len(array_1) > len(array_2): array_1 = array_1[: len(array_2)]
618
- elif len(array_1) < len(array_2):
619
- padding = len(array_2) - len(array_1)
620
- array_1 = np.pad(array_1, (0, padding), "constant", constant_values=0)
621
-
622
- return array_1
623
-
624
- def change_pitch_semitones(y, sr, semitone_shift):
625
- factor = 2 ** (semitone_shift / 12)
626
- y_pitch_tuned = []
627
-
628
- for y_channel in y:
629
- y_pitch_tuned.append(librosa.resample(y_channel, orig_sr=sr, target_sr=sr * factor, res_type=wav_resolution_float_resampling))
630
-
631
- y_pitch_tuned = np.array(y_pitch_tuned)
632
- new_sr = sr * factor
633
-
634
- return y_pitch_tuned, new_sr
635
-
636
- def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False, is_time_correction=True):
637
- wav, sr = librosa.load(audio_file, sr=44100, mono=False)
638
-
639
- if wav.ndim == 1: wav = np.asfortranarray([wav, wav])
640
-
641
- if not is_time_correction: wav_mix = change_pitch_semitones(wav, 44100, semitone_shift=-rate)[0]
642
- else:
643
- if is_pitch:
644
- wav_1 = pitch_shift(wav[0], sr, rate, rbargs=None)
645
- wav_2 = pitch_shift(wav[1], sr, rate, rbargs=None)
646
- else:
647
- wav_1 = time_stretch(wav[0], sr, rate, rbargs=None)
648
- wav_2 = time_stretch(wav[1], sr, rate, rbargs=None)
649
-
650
- if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)
651
- if wav_1.shape < wav_2.shape: wav_1 = to_shape(wav_1, wav_2.shape)
652
-
653
- wav_mix = np.asfortranarray([wav_1, wav_2])
654
-
655
- sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set)
656
- save_format(export_path)
657
-
658
-
659
- def average_audio(audio):
660
- waves = []
661
- wave_shapes = []
662
- final_waves = []
663
-
664
- for i in range(len(audio)):
665
- wave = librosa.load(audio[i], sr=44100, mono=False)
666
- waves.append(wave[0])
667
- wave_shapes.append(wave[0].shape[1])
668
-
669
- wave_shapes_index = wave_shapes.index(max(wave_shapes))
670
- target_shape = waves[wave_shapes_index]
671
-
672
- waves.pop(wave_shapes_index)
673
- final_waves.append(target_shape)
674
-
675
- for n_array in waves:
676
- wav_target = to_shape(n_array, target_shape.shape)
677
- final_waves.append(wav_target)
678
-
679
- waves = sum(final_waves)
680
- waves = waves / len(audio)
681
-
682
- return waves
683
-
684
- def average_dual_sources(wav_1, wav_2, value):
685
- if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)
686
- if wav_1.shape < wav_2.shape: wav_1 = to_shape(wav_1, wav_2.shape)
687
-
688
- wave = (wav_1 * value) + (wav_2 * (1 - value))
689
-
690
- return wave
691
-
692
- def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray):
693
- if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)
694
-
695
- if wav_1.shape < wav_2.shape:
696
- ln = min([wav_1.shape[1], wav_2.shape[1]])
697
- wav_2 = wav_2[:, :ln]
698
-
699
- ln = min([wav_1.shape[1], wav_2.shape[1]])
700
- wav_1 = wav_1[:, :ln]
701
- wav_2 = wav_2[:, :ln]
702
-
703
- return wav_2
704
-
705
- def reshape_sources_ref(wav_1_shape, wav_2: np.ndarray):
706
- if wav_1_shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1_shape)
707
-
708
- return wav_2
709
-
710
- def combine_arrarys(audio_sources, is_swap=False):
711
- source = np.zeros_like(max(audio_sources, key=np.size))
712
-
713
- for v in audio_sources:
714
- v = match_array_shapes(v, source, is_swap=is_swap)
715
- source += v
716
-
717
- return source
718
-
719
- def combine_audio(paths: list, audio_file_base=None, wav_type_set="FLOAT", save_format=None):
720
- source = combine_arrarys([load_audio(i) for i in paths])
721
- save_path = f"{audio_file_base}_combined.wav"
722
-
723
- sf.write(save_path, source.T, 44100, subtype=wav_type_set)
724
- save_format(save_path)
725
-
726
- def reduce_mix_bv(inst_source, voc_source, reduction_rate=0.9):
727
- inst_source = inst_source * (1 - reduction_rate)
728
- mix_reduced = combine_arrarys([inst_source, voc_source], is_swap=True)
729
-
730
- return mix_reduced
731
-
732
- def organize_inputs(inputs):
733
- input_list = {"target": None, "reference": None, "reverb": None, "inst": None}
734
-
735
- for i in inputs:
736
- if i.endswith("_(Vocals).wav"): input_list["reference"] = i
737
- elif "_RVC_" in i: input_list["target"] = i
738
- elif i.endswith("reverbed_stem.wav"): input_list["reverb"] = i
739
- elif i.endswith("_(Instrumental).wav"): input_list["inst"] = i
740
-
741
- return input_list
742
-
743
- def check_if_phase_inverted(wav1, wav2, is_mono=False):
744
- if not is_mono:
745
- wav1 = np.mean(wav1, axis=0)
746
- wav2 = np.mean(wav2, axis=0)
747
-
748
- correlation = np.corrcoef(wav1[:1000], wav2[:1000])
749
-
750
- return correlation[0, 1] < 0
751
-
752
- def align_audio(file1, file2, file2_aligned, file_subtracted, wav_type_set, is_save_aligned, command_Text, save_format, align_window: list, align_intro_val: list, db_analysis: tuple, set_progress_bar, phase_option, phase_shifts, is_match_silence, is_spec_match):
753
- global progress_value
754
- progress_value = 0
755
- is_mono = False
756
-
757
- def get_diff(a, b):
758
- corr = np.correlate(a, b, "full")
759
- diff = corr.argmax() - (b.shape[0] - 1)
760
-
761
- return diff
762
-
763
- def progress_bar(length):
764
- global progress_value
765
-
766
- progress_value += 1
767
-
768
- if (0.90 / length * progress_value) >= 0.9: length = progress_value + 1
769
-
770
- set_progress_bar(0.1, (0.9 / length * progress_value))
771
-
772
- if file1.endswith(".mp3") and is_macos:
773
- length1 = rerun_mp3(file1)
774
- wav1, sr1 = librosa.load(file1, duration=length1, sr=44100, mono=False)
775
- else: wav1, sr1 = librosa.load(file1, sr=44100, mono=False)
776
-
777
- if file2.endswith(".mp3") and is_macos:
778
- length2 = rerun_mp3(file2)
779
- wav2, sr2 = librosa.load(file2, duration=length2, sr=44100, mono=False)
780
- else: wav2, sr2 = librosa.load(file2, sr=44100, mono=False)
781
-
782
- if wav1.ndim == 1 and wav2.ndim == 1: is_mono = True
783
- elif wav1.ndim == 1: wav1 = np.asfortranarray([wav1, wav1])
784
- elif wav2.ndim == 1: wav2 = np.asfortranarray([wav2, wav2])
785
-
786
- if phase_option == AUTO_PHASE:
787
- if check_if_phase_inverted(wav1, wav2, is_mono=is_mono): wav2 = -wav2
788
- elif phase_option == POSITIVE_PHASE: wav2 = +wav2
789
- elif phase_option == NEGATIVE_PHASE: wav2 = -wav2
790
-
791
- if is_match_silence: wav2 = adjust_leading_silence(wav2, wav1)
792
-
793
- wav1_length = int(librosa.get_duration(y=wav1, sr=44100))
794
- wav2_length = int(librosa.get_duration(y=wav2, sr=44100))
795
-
796
- if not is_mono:
797
- wav1 = wav1.transpose()
798
- wav2 = wav2.transpose()
799
-
800
- wav2_org = wav2.copy()
801
-
802
- command_Text(translations["process_file"])
803
- seconds_length = min(wav1_length, wav2_length)
804
-
805
- wav2_aligned_sources = []
806
-
807
- for sec_len in align_intro_val:
808
- sec_seg = 1 if sec_len == 1 else int(seconds_length // sec_len)
809
- index = sr1 * sec_seg
810
-
811
- if is_mono:
812
- samp1, samp2 = wav1[index : index + sr1], wav2[index : index + sr1]
813
- diff = get_diff(samp1, samp2)
814
- else:
815
- index = sr1 * sec_seg
816
- samp1, samp2 = wav1[index : index + sr1, 0], wav2[index : index + sr1, 0]
817
- samp1_r, samp2_r = wav1[index : index + sr1, 1], wav2[index : index + sr1, 1]
818
- diff, diff_r = get_diff(samp1, samp2), get_diff(samp1_r, samp2_r)
819
-
820
- if diff > 0:
821
- zeros_to_append = np.zeros(diff) if is_mono else np.zeros((diff, 2))
822
- wav2_aligned = np.append(zeros_to_append, wav2_org, axis=0)
823
- elif diff < 0: wav2_aligned = wav2_org[-diff:]
824
- else: wav2_aligned = wav2_org
825
-
826
- if not any(np.array_equal(wav2_aligned, source) for source in wav2_aligned_sources): wav2_aligned_sources.append(wav2_aligned)
827
-
828
- unique_sources = len(wav2_aligned_sources)
829
-
830
- sub_mapper_big_mapper = {}
831
-
832
- for s in wav2_aligned_sources:
833
- wav2_aligned = match_mono_array_shapes(s, wav1) if is_mono else match_array_shapes(s, wav1, is_swap=True)
834
-
835
- if align_window:
836
- wav_sub = time_correction(wav1, wav2_aligned, seconds_length, align_window=align_window, db_analysis=db_analysis, progress_bar=progress_bar, unique_sources=unique_sources, phase_shifts=phase_shifts)
837
- wav_sub_size = np.abs(wav_sub).mean()
838
-
839
- sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}}
840
- else:
841
- wav2_aligned = wav2_aligned * np.power(10, db_analysis[0] / 20)
842
- db_range = db_analysis[1]
843
-
844
- for db_adjustment in db_range:
845
- s_adjusted = wav2_aligned * (10 ** (db_adjustment / 20))
846
-
847
- wav_sub = wav1 - s_adjusted
848
- wav_sub_size = np.abs(wav_sub).mean()
849
-
850
- sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}}
851
-
852
- sub_mapper_value_list = list(sub_mapper_big_mapper.values())
853
-
854
- wav_sub = ensemble_for_align(list(sub_mapper_big_mapper.values())) if is_spec_match and len(sub_mapper_value_list) >= 2 else ensemble_wav(list(sub_mapper_big_mapper.values()))
855
-
856
- wav_sub = np.clip(wav_sub, -1, +1)
857
-
858
- command_Text(translations["save_instruments"])
859
-
860
- if is_save_aligned or is_spec_match:
861
- wav1 = match_mono_array_shapes(wav1, wav_sub) if is_mono else match_array_shapes(wav1, wav_sub, is_swap=True)
862
- wav2_aligned = wav1 - wav_sub
863
-
864
- if is_spec_match:
865
- if wav1.ndim == 1 and wav2.ndim == 1:
866
- wav2_aligned = np.asfortranarray([wav2_aligned, wav2_aligned]).T
867
- wav1 = np.asfortranarray([wav1, wav1]).T
868
-
869
- wav2_aligned = ensemble_for_align([wav2_aligned, wav1])
870
- wav_sub = wav1 - wav2_aligned
871
-
872
- if is_save_aligned:
873
- sf.write(file2_aligned, wav2_aligned, sr1, subtype=wav_type_set)
874
- save_format(file2_aligned)
875
-
876
- sf.write(file_subtracted, wav_sub, sr1, subtype=wav_type_set)
877
- save_format(file_subtracted)
878
-
879
- def phase_shift_hilbert(signal, degree):
880
- analytic_signal = hilbert(signal)
881
-
882
- return np.cos(np.radians(degree)) * analytic_signal.real - np.sin(np.radians(degree)) * analytic_signal.imag
883
-
884
- def get_phase_shifted_tracks(track, phase_shift):
885
- if phase_shift == 180: return [track, -track]
886
-
887
- step = phase_shift
888
- end = 180 - (180 % step) if 180 % step == 0 else 181
889
- phase_range = range(step, end, step)
890
-
891
- flipped_list = [track, -track]
892
-
893
- for i in phase_range:
894
- flipped_list.extend([phase_shift_hilbert(track, i), phase_shift_hilbert(track, -i)])
895
-
896
- return flipped_list
897
-
898
- def time_correction(mix: np.ndarray, instrumental: np.ndarray, seconds_length, align_window, db_analysis, sr=44100, progress_bar=None, unique_sources=None, phase_shifts=NONE_P):
899
- def align_tracks(track1, track2):
900
- shifted_tracks = {}
901
-
902
- track2 = track2 * np.power(10, db_analysis[0] / 20)
903
- db_range = db_analysis[1]
904
-
905
- track2_flipped = [track2] if phase_shifts == 190 else get_phase_shifted_tracks(track2, phase_shifts)
906
-
907
- for db_adjustment in db_range:
908
- for t in track2_flipped:
909
- track2_adjusted = t * (10 ** (db_adjustment / 20))
910
- corr = correlate(track1, track2_adjusted)
911
- delay = np.argmax(np.abs(corr)) - (len(track1) - 1)
912
- track2_shifted = np.roll(track2_adjusted, shift=delay)
913
-
914
- track2_shifted_sub = track1 - track2_shifted
915
- mean_abs_value = np.abs(track2_shifted_sub).mean()
916
-
917
- shifted_tracks[mean_abs_value] = track2_shifted
918
-
919
- return shifted_tracks[min(shifted_tracks.keys())]
920
-
921
- assert mix.shape == instrumental.shape, translations["assert"].format(mixshape=mix.shape, instrumentalshape=instrumental.shape)
922
-
923
- seconds_length = seconds_length // 2
924
-
925
- sub_mapper = {}
926
-
927
- progress_update_interval = 120
928
- total_iterations = 0
929
-
930
- if len(align_window) > 2: progress_update_interval = 320
931
-
932
- for secs in align_window:
933
- step = secs / 2
934
- window_size = int(sr * secs)
935
- step_size = int(sr * step)
936
-
937
- if len(mix.shape) == 1:
938
- total_mono = (len(range(0, len(mix) - window_size, step_size)) // progress_update_interval) * unique_sources
939
- total_iterations += total_mono
940
- else:
941
- total_stereo_ = len(range(0, len(mix[:, 0]) - window_size, step_size)) * 2
942
- total_stereo = (total_stereo_ // progress_update_interval) * unique_sources
943
- total_iterations += total_stereo
944
-
945
- for secs in align_window:
946
- sub = np.zeros_like(mix)
947
- divider = np.zeros_like(mix)
948
- step = secs / 2
949
- window_size = int(sr * secs)
950
- step_size = int(sr * step)
951
- window = np.hanning(window_size)
952
-
953
- if len(mix.shape) == 1:
954
- counter = 0
955
-
956
- for i in range(0, len(mix) - window_size, step_size):
957
- counter += 1
958
-
959
- if counter % progress_update_interval == 0: progress_bar(total_iterations)
960
-
961
- window_mix = mix[i : i + window_size] * window
962
- window_instrumental = instrumental[i : i + window_size] * window
963
- window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
964
-
965
- sub[i : i + window_size] += window_mix - window_instrumental_aligned
966
- divider[i : i + window_size] += window
967
- else:
968
- counter = 0
969
-
970
- for ch in range(mix.shape[1]):
971
- for i in range(0, len(mix[:, ch]) - window_size, step_size):
972
- counter += 1
973
-
974
- if counter % progress_update_interval == 0: progress_bar(total_iterations)
975
-
976
- window_mix = mix[i : i + window_size, ch] * window
977
- window_instrumental = instrumental[i : i + window_size, ch] * window
978
- window_instrumental_aligned = align_tracks(window_mix, window_instrumental)
979
-
980
- sub[i : i + window_size, ch] += window_mix - window_instrumental_aligned
981
- divider[i : i + window_size, ch] += window
982
-
983
- sub = np.where(divider > 1e-6, sub / divider, sub)
984
- sub_size = np.abs(sub).mean()
985
- sub_mapper = {**sub_mapper, **{sub_size: sub}}
986
-
987
- sub = ensemble_wav(list(sub_mapper.values()), split_size=12)
988
-
989
- return sub
990
-
991
- def ensemble_wav(waveforms, split_size=240):
992
- waveform_thirds = {i: np.array_split(waveform, split_size) for i, waveform in enumerate(waveforms)}
993
-
994
- final_waveform = []
995
-
996
- for third_idx in range(split_size):
997
- means = [np.abs(waveform_thirds[i][third_idx]).mean() for i in range(len(waveforms))]
998
-
999
- min_index = np.argmin(means)
1000
-
1001
- final_waveform.append(waveform_thirds[min_index][third_idx])
1002
-
1003
- final_waveform = np.concatenate(final_waveform)
1004
-
1005
- return final_waveform
1006
-
1007
- def ensemble_wav_min(waveforms):
1008
- for i in range(1, len(waveforms)):
1009
- if i == 1: wave = waveforms[0]
1010
-
1011
- ln = min(len(wave), len(waveforms[i]))
1012
- wave = wave[:ln]
1013
- waveforms[i] = waveforms[i][:ln]
1014
- wave = np.where(np.abs(waveforms[i]) <= np.abs(wave), waveforms[i], wave)
1015
-
1016
- return wave
1017
-
1018
- def align_audio_test(wav1, wav2, sr1=44100):
1019
- def get_diff(a, b):
1020
- corr = np.correlate(a, b, "full")
1021
- diff = corr.argmax() - (b.shape[0] - 1)
1022
- return diff
1023
-
1024
- wav1 = wav1.transpose()
1025
- wav2 = wav2.transpose()
1026
- wav2_org = wav2.copy()
1027
-
1028
- index = sr1
1029
- samp1 = wav1[index : index + sr1, 0]
1030
- samp2 = wav2[index : index + sr1, 0]
1031
- diff = get_diff(samp1, samp2)
1032
-
1033
- if diff > 0: wav2_aligned = np.append(np.zeros((diff, 1)), wav2_org, axis=0)
1034
- elif diff < 0: wav2_aligned = wav2_org[-diff:]
1035
- else: wav2_aligned = wav2_org
1036
-
1037
- return wav2_aligned
1038
-
1039
- def load_audio(audio_file):
1040
- wav, sr = librosa.load(audio_file, sr=44100, mono=False)
1041
- if wav.ndim == 1: wav = np.asfortranarray([wav, wav])
1042
-
1043
- return wav
1044
-
1045
- def rerun_mp3(audio_file):
1046
- with audioread.audio_open(audio_file) as f:
1047
- track_length = int(f.duration)
1048
-
1049
- return track_length
1050
-
1051
- def __rubberband(y, sr, **kwargs):
1052
- assert sr > 0
1053
-
1054
- fd, infile = tempfile.mkstemp(suffix='.wav')
1055
- os.close(fd)
1056
- fd, outfile = tempfile.mkstemp(suffix='.wav')
1057
- os.close(fd)
1058
-
1059
- sf.write(infile, y, sr)
1060
-
1061
- try:
1062
- arguments = [os.path.join(BASE_PATH_RUB, 'rubberband'), '-q']
1063
-
1064
- for key, value in six.iteritems(kwargs):
1065
- arguments.append(str(key))
1066
- arguments.append(str(value))
1067
-
1068
- arguments.extend([infile, outfile])
1069
-
1070
- subprocess.check_call(arguments, stdout=DEVNULL, stderr=DEVNULL)
1071
-
1072
- y_out, _ = sf.read(outfile, always_2d=True)
1073
-
1074
- if y.ndim == 1: y_out = np.squeeze(y_out)
1075
- except OSError as exc:
1076
- six.raise_from(RuntimeError(translations["rubberband"]), exc)
1077
- finally:
1078
- os.unlink(infile)
1079
- os.unlink(outfile)
1080
-
1081
- return y_out
1082
-
1083
- def time_stretch(y, sr, rate, rbargs=None):
1084
- if rate <= 0: raise ValueError(translations["rate"])
1085
-
1086
- if rate == 1.0: return y
1087
- if rbargs is None: rbargs = dict()
1088
-
1089
- rbargs.setdefault('--tempo', rate)
1090
-
1091
- return __rubberband(y, sr, **rbargs)
1092
-
1093
- def pitch_shift(y, sr, n_steps, rbargs=None):
1094
-
1095
- if n_steps == 0: return y
1096
- if rbargs is None: rbargs = dict()
1097
-
1098
- rbargs.setdefault('--pitch', n_steps)
1099
-
1100
- return __rubberband(y, sr, **rbargs)
 
 
 
main/tools/gdown.py DELETED
@@ -1,230 +0,0 @@
1
- import os
2
- import re
3
- import sys
4
- import six
5
- import json
6
- import tqdm
7
- import shutil
8
- import tempfile
9
- import requests
10
- import warnings
11
- import textwrap
12
-
13
- from time import sleep, time
14
- from urllib.parse import urlparse, parse_qs, unquote
15
-
16
- now_dir = os.getcwd()
17
- sys.path.append(now_dir)
18
-
19
- from main.configs.config import Config
20
- translations = Config().translations
21
-
22
- CHUNK_SIZE = 512 * 1024
23
- HOME = os.path.expanduser("~")
24
-
25
- def indent(text, prefix):
26
- return "".join((prefix + line if line.strip() else line) for line in text.splitlines(True))
27
-
28
-
29
- def parse_url(url, warning=True):
30
- parsed = urlparse(url)
31
- is_download_link = parsed.path.endswith("/uc")
32
-
33
- if not parsed.hostname in ("drive.google.com", "docs.google.com"): return None, is_download_link
34
-
35
- file_id = parse_qs(parsed.query).get("id", [None])[0]
36
-
37
- if file_id is None:
38
- for pattern in (r"^/file/d/(.*?)/(edit|view)$", r"^/file/u/[0-9]+/d/(.*?)/(edit|view)$", r"^/document/d/(.*?)/(edit|htmlview|view)$", r"^/document/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$"):
39
- match = re.match(pattern, parsed.path)
40
-
41
- if match:
42
- file_id = match.group(1)
43
- break
44
-
45
- if warning and not is_download_link:
46
- warnings.warn(translations["gdown_warning"].format(file_id=file_id))
47
-
48
- return file_id, is_download_link
49
-
50
-
51
- def get_url_from_gdrive_confirmation(contents):
52
- for pattern in (r'href="(\/uc\?export=download[^"]+)', r'href="/open\?id=([^"]+)"', r'"downloadUrl":"([^"]+)'):
53
- match = re.search(pattern, contents)
54
-
55
- if match:
56
- url = match.group(1)
57
-
58
- if pattern == r'href="/open\?id=([^"]+)"': url = ("https://drive.usercontent.google.com/download?id=" + url + "&confirm=t&uuid=" + re.search(r'<input\s+type="hidden"\s+name="uuid"\s+value="([^"]+)"', contents).group(1))
59
- elif pattern == r'"downloadUrl":"([^"]+)': url = url.replace("\\u003d", "=").replace("\\u0026", "&")
60
- else: url = "https://docs.google.com" + url.replace("&", "&")
61
-
62
- return url
63
-
64
- match = re.search(r'<p class="uc-error-subcaption">(.*)</p>', contents)
65
-
66
- if match:
67
- error = match.group(1)
68
- raise Exception(error)
69
-
70
- raise Exception(translations["gdown_error"])
71
-
72
-
73
- def _get_session(proxy, use_cookies, return_cookies_file=False):
74
- sess = requests.session()
75
- sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})
76
-
77
- if proxy is not None:
78
- sess.proxies = {"http": proxy, "https": proxy}
79
- print("Using proxy:", proxy, file=sys.stderr)
80
-
81
- cookies_file = os.path.join(HOME, ".cache/gdown/cookies.json")
82
-
83
- if os.path.exists(cookies_file) and use_cookies:
84
- with open(cookies_file) as f:
85
- cookies = json.load(f)
86
-
87
- for k, v in cookies:
88
- sess.cookies[k] = v
89
-
90
- return (sess, cookies_file) if return_cookies_file else sess
91
-
92
-
93
- def gdown_download(url=None, output=None, output_dir=None, quiet=False, proxy=None, speed=None, use_cookies=True, verify=True, id=None, fuzzy=True, resume=False, format=None):
94
- if not (id is None) ^ (url is None): raise ValueError(translations["gdown_value_error"])
95
- if id is not None: url = f"https://drive.google.com/uc?id={id}"
96
-
97
- url_origin = url
98
-
99
- sess, cookies_file = _get_session(proxy=proxy, use_cookies=use_cookies, return_cookies_file=True)
100
-
101
- gdrive_file_id, is_gdrive_download_link = parse_url(url, warning=not fuzzy)
102
-
103
-
104
- if fuzzy and gdrive_file_id:
105
- url = f"https://drive.google.com/uc?id={gdrive_file_id}"
106
- url_origin = url
107
- is_gdrive_download_link = True
108
-
109
- while 1:
110
- res = sess.get(url, stream=True, verify=verify)
111
-
112
- if url == url_origin and res.status_code == 500:
113
- url = f"https://drive.google.com/open?id={gdrive_file_id}"
114
- continue
115
-
116
- if res.headers["Content-Type"].startswith("text/html"):
117
- title = re.search("<title>(.+)</title>", res.text)
118
-
119
- if title:
120
- title = title.group(1)
121
- if title.endswith(" - Google Docs"):
122
- url = f"https://docs.google.com/document/d/{gdrive_file_id}/export?format={'docx' if format is None else format}"
123
- continue
124
- if title.endswith(" - Google Sheets"):
125
- url = f"https://docs.google.com/spreadsheets/d/{gdrive_file_id}/export?format={'xlsx' if format is None else format}"
126
- continue
127
- if title.endswith(" - Google Slides"):
128
- url = f"https://docs.google.com/presentation/d/{gdrive_file_id}/export?format={'pptx' if format is None else format}"
129
- continue
130
- elif ("Content-Disposition" in res.headers and res.headers["Content-Disposition"].endswith("pptx") and format not in (None, "pptx")):
131
- url = f"https://docs.google.com/presentation/d/{gdrive_file_id}/export?format={'pptx' if format is None else format}"
132
- continue
133
-
134
- if use_cookies:
135
- os.makedirs(os.path.dirname(cookies_file), exist_ok=True)
136
-
137
- with open(cookies_file, "w") as f:
138
- cookies = [(k, v) for k, v in sess.cookies.items() if not k.startswith("download_warning_")]
139
- json.dump(cookies, f, indent=2)
140
-
141
- if "Content-Disposition" in res.headers: break
142
- if not (gdrive_file_id and is_gdrive_download_link): break
143
-
144
-
145
- try:
146
- url = get_url_from_gdrive_confirmation(res.text)
147
- except Exception as e:
148
- error = indent("\n".join(textwrap.wrap(str(e))), prefix="\t")
149
- raise Exception(translations["gdown_error_2"].format(error=error, url_origin=url_origin))
150
-
151
- if gdrive_file_id and is_gdrive_download_link:
152
- content_disposition = unquote(res.headers["Content-Disposition"])
153
- filename_from_url = (re.search(r"filename\*=UTF-8''(.*)", content_disposition) or re.search(r'filename=["\']?(.*?)["\']?$', content_disposition)).group(1)
154
- filename_from_url = filename_from_url.replace(os.path.sep, "_")
155
- else: filename_from_url = os.path.basename(url)
156
-
157
- output = output or filename_from_url
158
- output_is_path = isinstance(output, six.string_types)
159
-
160
- if output_is_path and output.endswith(os.path.sep):
161
- os.makedirs(output, exist_ok=True)
162
- output = os.path.join(output, filename_from_url)
163
-
164
- if output_is_path:
165
- temp_dir = os.path.dirname(output) or "."
166
- prefix = os.path.basename(output)
167
- existing_tmp_files = [os.path.join(temp_dir, file) for file in os.listdir(temp_dir) if file.startswith(prefix)]
168
-
169
- if resume and existing_tmp_files:
170
- if len(existing_tmp_files) > 1:
171
- print(translations["temps"], file=sys.stderr)
172
-
173
- for file in existing_tmp_files:
174
- print(f"\t{file}", file=sys.stderr)
175
-
176
- print(translations["del_all_temps"], file=sys.stderr)
177
- return
178
-
179
- tmp_file = existing_tmp_files[0]
180
- else:
181
- resume = False
182
- tmp_file = tempfile.mktemp(suffix=tempfile.template, prefix=prefix, dir=temp_dir)
183
-
184
- f = open(tmp_file, "ab")
185
- else:
186
- tmp_file = None
187
- f = output
188
-
189
-
190
- if tmp_file is not None and f.tell() != 0: res = sess.get(url, headers={"Range": f"bytes={f.tell()}-"}, stream=True, verify=verify)
191
-
192
- if not quiet:
193
- if resume: print(translations["continue"], tmp_file, file=sys.stderr)
194
-
195
- print(translations["to"], os.path.abspath(output) if output_is_path else output, file=sys.stderr)
196
-
197
- try:
198
- if not quiet: pbar = tqdm.tqdm(total=int(res.headers.get("Content-Length", 0)))
199
-
200
- t_start = time()
201
-
202
- for chunk in res.iter_content(chunk_size=CHUNK_SIZE):
203
- f.write(chunk)
204
-
205
- if not quiet: pbar.update(len(chunk))
206
-
207
- if speed is not None:
208
- elapsed_time_expected = 1.0 * pbar.n / speed
209
- elapsed_time = time() - t_start
210
-
211
- if elapsed_time < elapsed_time_expected: sleep(elapsed_time_expected - elapsed_time)
212
-
213
- if not quiet: pbar.close()
214
-
215
- if tmp_file:
216
- f.close()
217
-
218
- if output_dir is not None:
219
- output_file = os.path.join(output_dir, output)
220
- if os.path.exists(output_file): os.remove(output_file)
221
-
222
- shutil.move(tmp_file, output_file)
223
- else:
224
- if os.path.exists(output): os.remove(output)
225
-
226
- shutil.move(tmp_file, output)
227
- finally:
228
- sess.close()
229
-
230
- return output
 
 
 
main/tools/mediafire.py DELETED
@@ -1,30 +0,0 @@
1
- import os
2
- import sys
3
- import requests
4
- from bs4 import BeautifulSoup
5
-
6
- def Mediafire_Download(url, output=None, filename=None):
7
- if not filename: filename = url.split('/')[-2]
8
- if not output: output = os.path.dirname(os.path.realpath(__file__))
9
- output_file = os.path.join(output, filename)
10
-
11
- sess = requests.session()
12
- sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})
13
-
14
- try:
15
- with requests.get(BeautifulSoup(sess.get(url).content, "html.parser").find(id="downloadButton").get("href"), stream=True) as r:
16
- r.raise_for_status()
17
- with open(output_file, "wb") as f:
18
- total_length = r.headers.get('content-length')
19
- total_length = int(total_length)
20
- download_progress = 0
21
-
22
- for chunk in r.iter_content(chunk_size=1024):
23
- download_progress += len(chunk)
24
- f.write(chunk)
25
- sys.stdout.write(f"\r[{filename}]: {int(100 * download_progress/total_length)}% ({round(download_progress/1024/1024, 2)}mb/{round(total_length/1024/1024, 2)}mb)")
26
- sys.stdout.flush()
27
- sys.stdout.write("\n")
28
- return output_file
29
- except Exception as e:
30
- raise RuntimeError(e)
 
 
 
main/tools/meganz.py DELETED
@@ -1,180 +0,0 @@
1
- import os
2
- import re
3
- import sys
4
- import json
5
- import tqdm
6
- import codecs
7
- import random
8
- import base64
9
- import struct
10
- import shutil
11
- import requests
12
- import tempfile
13
-
14
- from Crypto.Cipher import AES
15
- from Crypto.Util import Counter
16
- from tenacity import retry, wait_exponential, retry_if_exception_type
17
-
18
-
19
- now_dir = os.getcwd()
20
- sys.path.append(now_dir)
21
-
22
- from main.configs.config import Config
23
- translations = Config().translations
24
-
25
-
26
- def makebyte(x):
27
- return codecs.latin_1_encode(x)[0]
28
-
29
-
30
- def a32_to_str(a):
31
- return struct.pack('>%dI' % len(a), *a)
32
-
33
-
34
- def get_chunks(size):
35
- p = 0
36
- s = 0x20000
37
-
38
- while p + s < size:
39
- yield (p, s)
40
- p += s
41
-
42
- if s < 0x100000: s += 0x20000
43
-
44
- yield (p, size - p)
45
-
46
-
47
- def decrypt_attr(attr, key):
48
- attr = AES.new(a32_to_str(key), AES.MODE_CBC, makebyte('\0' * 16)).decrypt(attr)
49
- attr = codecs.latin_1_decode(attr)[0]
50
- attr = attr.rstrip('\0')
51
-
52
- return json.loads(attr[4:]) if attr[:6] == 'MEGA{"' else False
53
-
54
-
55
- @retry(retry=retry_if_exception_type(RuntimeError), wait=wait_exponential(multiplier=2, min=2, max=60))
56
- def _api_request(data):
57
- sequence_num = random.randint(0, 0xFFFFFFFF)
58
- params = {'id': sequence_num}
59
- sequence_num += 1
60
-
61
- if not isinstance(data, list): data = [data]
62
-
63
- json_resp = json.loads(requests.post(f'https://g.api.mega.co.nz/cs', params=params, data=json.dumps(data), timeout=160).text)
64
-
65
-
66
- try:
67
- if isinstance(json_resp, list): int_resp = json_resp[0] if isinstance(json_resp[0], int) else None
68
- elif isinstance(json_resp, int): int_resp = json_resp
69
- except IndexError:
70
- int_resp = None
71
-
72
- if int_resp is not None:
73
- if int_resp == 0: return int_resp
74
- if int_resp == -3: raise RuntimeError('int_resp==-3')
75
-
76
- raise Exception(int_resp)
77
-
78
- return json_resp[0]
79
-
80
-
81
- def base64_url_decode(data):
82
- data += '=='[(2 - len(data) * 3) % 4:]
83
-
84
- for search, replace in (('-', '+'), ('_', '/'), (',', '')):
85
- data = data.replace(search, replace)
86
-
87
- return base64.b64decode(data)
88
-
89
-
90
- def str_to_a32(b):
91
- if isinstance(b, str): b = makebyte(b)
92
- if len(b) % 4: b += b'\0' * (4 - len(b) % 4)
93
-
94
- return struct.unpack('>%dI' % (len(b) / 4), b)
95
-
96
-
97
- def mega_download_file(file_handle, file_key, dest_path=None, dest_filename=None, file=None):
98
- if file is None:
99
- file_key = str_to_a32(base64_url_decode(file_key))
100
- file_data = _api_request({'a': 'g', 'g': 1, 'p': file_handle})
101
-
102
- k = (file_key[0] ^ file_key[4], file_key[1] ^ file_key[5], file_key[2] ^ file_key[6], file_key[3] ^ file_key[7])
103
- iv = file_key[4:6] + (0, 0)
104
- meta_mac = file_key[6:8]
105
- else:
106
- file_data = _api_request({'a': 'g', 'g': 1, 'n': file['h']})
107
- k = file['k']
108
- iv = file['iv']
109
- meta_mac = file['meta_mac']
110
-
111
- if 'g' not in file_data: raise Exception(translations["file_not_access"])
112
-
113
- file_size = file_data['s']
114
-
115
- attribs = base64_url_decode(file_data['at'])
116
- attribs = decrypt_attr(attribs, k)
117
-
118
- file_name = dest_filename if dest_filename is not None else attribs['n']
119
-
120
- input_file = requests.get(file_data['g'], stream=True).raw
121
-
122
- if dest_path is None: dest_path = ''
123
- else: dest_path += '/'
124
-
125
- temp_output_file = tempfile.NamedTemporaryFile(mode='w+b', prefix='megapy_', delete=False)
126
-
127
- k_str = a32_to_str(k)
128
-
129
- counter = Counter.new(128, initial_value=((iv[0] << 32) + iv[1]) << 64)
130
- aes = AES.new(k_str, AES.MODE_CTR, counter=counter)
131
-
132
- mac_str = b'\0' * 16
133
- mac_encryptor = AES.new(k_str, AES.MODE_CBC, mac_str)
134
-
135
- iv_str = a32_to_str([iv[0], iv[1], iv[0], iv[1]])
136
-
137
- pbar = tqdm.tqdm(total=file_size)
138
-
139
- for _, chunk_size in get_chunks(file_size):
140
- chunk = input_file.read(chunk_size)
141
- chunk = aes.decrypt(chunk)
142
- temp_output_file.write(chunk)
143
-
144
- pbar.update(len(chunk))
145
-
146
- encryptor = AES.new(k_str, AES.MODE_CBC, iv_str)
147
-
148
- for i in range(0, len(chunk)-16, 16):
149
- block = chunk[i:i + 16]
150
- encryptor.encrypt(block)
151
-
152
- if file_size > 16: i += 16
153
- else: i = 0
154
-
155
- block = chunk[i:i + 16]
156
- if len(block) % 16: block += b'\0' * (16 - (len(block) % 16))
157
-
158
- mac_str = mac_encryptor.encrypt(encryptor.encrypt(block))
159
-
160
- file_mac = str_to_a32(mac_str)
161
- temp_output_file.close()
162
-
163
- if (file_mac[0] ^ file_mac[1], file_mac[2] ^ file_mac[3]) != meta_mac: raise ValueError(translations["mac_not_match"])
164
-
165
- file_path = os.path.join(dest_path, file_name)
166
- if os.path.exists(file_path): os.remove(file_path)
167
-
168
- shutil.move(temp_output_file.name, file_path)
169
-
170
-
171
- def mega_download_url(url, dest_path=None, dest_filename=None):
172
- if '/file/' in url:
173
- url = url.replace(' ', '')
174
- file_id = re.findall(r'\W\w\w\w\w\w\w\w\w\W', url)[0][1:-1]
175
-
176
- path = f'{file_id}!{url[re.search(file_id, url).end() + 1:]}'.split('!')
177
- elif '!' in url: path = re.findall(r'/#!(.*)', url)[0].split('!')
178
- else: raise Exception(translations["missing_url"])
179
-
180
- return mega_download_file(file_handle=path[0], file_key=path[1], dest_path=dest_path, dest_filename=dest_filename)
 
 
 
main/tools/pixeldrain.py DELETED
@@ -1,19 +0,0 @@
1
- import os
2
- import requests
3
-
4
-
5
- def pixeldrain(url, output_dir):
6
- try:
7
- file_id = url.split("pixeldrain.com/u/")[1]
8
- response = requests.get(f"https://pixeldrain.com/api/file/{file_id}")
9
-
10
- if response.status_code == 200:
11
- file_name = (response.headers.get("Content-Disposition").split("filename=")[-1].strip('";'))
12
- file_path = os.path.join(output_dir, file_name)
13
-
14
- with open(file_path, "wb") as newfile:
15
- newfile.write(response.content)
16
- return file_path
17
- else: return None
18
- except Exception as e:
19
- raise RuntimeError(e)