Delete main
- main/app/app.py +0 -0
- main/app/clean.py +0 -40
- main/app/run_tensorboard.py +0 -30
- main/app/sync.py +0 -80
- main/configs/config.json +0 -229
- main/configs/config.py +0 -120
- main/configs/v1/32000.json +0 -82
- main/configs/v1/40000.json +0 -80
- main/configs/v1/48000.json +0 -82
- main/configs/v2/32000.json +0 -76
- main/configs/v2/40000.json +0 -76
- main/configs/v2/48000.json +0 -76
- main/inference/audio_effects.py +0 -170
- main/inference/convert.py +0 -1060
- main/inference/create_dataset.py +0 -370
- main/inference/create_index.py +0 -120
- main/inference/extract.py +0 -450
- main/inference/preprocess.py +0 -360
- main/inference/separator_music.py +0 -400
- main/inference/train.py +0 -1600
- main/library/algorithm/commons.py +0 -100
- main/library/algorithm/modules.py +0 -80
- main/library/algorithm/residuals.py +0 -170
- main/library/algorithm/separator.py +0 -420
- main/library/algorithm/synthesizers.py +0 -590
- main/library/architectures/demucs_separator.py +0 -340
- main/library/architectures/mdx_separator.py +0 -370
- main/library/predictors/FCPE.py +0 -600
- main/library/predictors/RMVPE.py +0 -270
- main/library/uvr5_separator/common_separator.py +0 -270
- main/library/uvr5_separator/demucs/apply.py +0 -280
- main/library/uvr5_separator/demucs/demucs.py +0 -340
- main/library/uvr5_separator/demucs/hdemucs.py +0 -850
- main/library/uvr5_separator/demucs/htdemucs.py +0 -690
- main/library/uvr5_separator/demucs/states.py +0 -70
- main/library/uvr5_separator/demucs/utils.py +0 -10
- main/library/uvr5_separator/spec_utils.py +0 -1100
- main/tools/gdown.py +0 -230
- main/tools/mediafire.py +0 -30
- main/tools/meganz.py +0 -180
- main/tools/pixeldrain.py +0 -19
main/app/app.py
DELETED
The diff for this file is too large to render.
main/app/clean.py
DELETED
@@ -1,40 +0,0 @@
import time
import threading

from googleapiclient.discovery import build


class Clean:
    def __init__(self, every=300):
        self.service = build('drive', 'v3')
        self.every = every
        self.trash_cleanup_thread = None

    def delete(self):
        page_token = None

        while 1:
            response = self.service.files().list(q="trashed=true", spaces='drive', fields="nextPageToken, files(id, name)", pageToken=page_token).execute()

            for file in response.get('files', []):
                if file['name'].startswith("G_") or file['name'].startswith("D_"):
                    try:
                        self.service.files().delete(fileId=file['id']).execute()
                    except Exception as e:
                        raise RuntimeError(e)

            page_token = response.get('nextPageToken', None)
            if page_token is None: break

    def clean(self):
        while 1:
            self.delete()
            time.sleep(self.every)

    def start(self):
        self.trash_cleanup_thread = threading.Thread(target=self.clean)
        self.trash_cleanup_thread.daemon = True
        self.trash_cleanup_thread.start()

    def stop(self):
        if self.trash_cleanup_thread: self.trash_cleanup_thread.join()
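For context, the Clean class above polled the Google Drive v3 trash and purged G_*/D_* training checkpoints on a daemon thread. A minimal usage sketch, assuming Drive credentials are already discoverable by googleapiclient (e.g. on Colab):

    cleaner = Clean(every=300)   # scan the Drive trash every 5 minutes
    cleaner.start()              # Clean.clean() loops on a daemon thread
    # ... training runs; trashed checkpoints get purged in the background ...
    cleaner.stop()               # joins the thread; note the loop never exits on its own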
main/app/run_tensorboard.py
DELETED
@@ -1,30 +0,0 @@
import os
import sys
import json
import logging
import webbrowser

from tensorboard import program

sys.path.append(os.getcwd())

from main.configs.config import Config
translations = Config().translations

with open(os.path.join("main", "configs", "config.json"), "r") as f:
    configs = json.load(f)

def launch_tensorboard_pipeline():
    logging.getLogger("root").setLevel(logging.ERROR)
    logging.getLogger("tensorboard").setLevel(logging.ERROR)

    tb = program.TensorBoard()
    tb.configure(argv=[None, "--logdir", "assets/logs", f"--port={configs['tensorboard_port']}"])
    url = tb.launch()

    print(f"{translations['tensorboard_url']}: {url}")
    webbrowser.open(url)

    return f"{translations['tensorboard_url']}: {url}"

if __name__ == "__main__": launch_tensorboard_pipeline()
main/app/sync.py
DELETED
@@ -1,80 +0,0 @@
import time
import threading
import subprocess

from typing import List, Union


class Channel:
    def __init__(self, source, destination, sync_deletions=False, every=60, exclude: Union[str, List, None] = None):
        self.source = source
        self.destination = destination
        self.event = threading.Event()
        self.syncing_thread = threading.Thread(target=self._sync, args=())
        self.sync_deletions = sync_deletions
        self.every = every

        if not exclude: exclude = []
        if isinstance(exclude, str): exclude = [exclude]

        self.exclude = exclude
        self.command = ['rsync', '-aP']

    def alive(self):
        if self.syncing_thread.is_alive(): return True
        else: return False

    def _sync(self):
        command = self.command

        for exclusion in self.exclude:
            command.append(f'--exclude={exclusion}')

        command.extend([f'{self.source}/', f'{self.destination}/'])

        if self.sync_deletions: command.append('--delete')

        while not self.event.is_set():
            subprocess.run(command)
            time.sleep(self.every)

    def copy(self):
        command = self.command

        for exclusion in self.exclude:
            command.append(f'--exclude={exclusion}')

        command.extend([f'{self.source}/', f'{self.destination}/'])

        if self.sync_deletions: command.append('--delete')
        subprocess.run(command)

        return True

    def start(self):
        if self.syncing_thread.is_alive():
            self.event.set()
            self.syncing_thread.join()

        if self.event.is_set(): self.event.clear()
        if self.syncing_thread._started.is_set(): self.syncing_thread = threading.Thread(target=self._sync, args=())

        self.syncing_thread.start()

        return self.alive()

    def stop(self):
        if self.alive():
            self.event.set()
            self.syncing_thread.join()

        while self.alive():
            if not self.alive(): break

        return not self.alive()
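A usage sketch for the Channel class above (hypothetical paths; rsync must be on PATH). Note that _sync and copy both append to the shared self.command list, so arguments accumulate across repeated calls on the same instance:

    channel = Channel("/content/assets", "/content/drive/MyDrive/backup", every=60, exclude=["mute"])
    channel.start()   # re-runs `rsync -aP source/ destination/` every 60 seconds on a thread
    channel.stop()    # sets the stop event and joins the syncing thread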
main/configs/config.json
DELETED
@@ -1,229 +0,0 @@
{
    "language": "vi-VN",
    "support_language": ["en-US", "vi-VN"],

    "theme": "NoCrypt/miku",
    "themes": [
        "NoCrypt/miku", "gstaff/xkcd", "JohnSmith9982/small_and_pretty", "ParityError/Interstellar",
        "earneleh/paris", "shivi/calm_seafoam", "Hev832/Applio", "YTheme/Minecraft",
        "gstaff/sketch", "SebastianBravo/simci_css", "allenai/gradio-theme", "Nymbo/Nymbo_Theme_5",
        "lone17/kotaemon", "Zarkel/IBM_Carbon_Theme", "SherlockRamos/Feliz", "freddyaboulton/dracula_revamped",
        "freddyaboulton/bad-theme-space", "gradio/dracula_revamped", "abidlabs/dracula_revamped", "gradio/dracula_test",
        "gradio/seafoam", "gradio/glass", "gradio/monochrome", "gradio/soft",
        "gradio/default", "gradio/base", "abidlabs/pakistan", "dawood/microsoft_windows",
        "ysharma/steampunk", "ysharma/huggingface", "abidlabs/Lime", "freddyaboulton/this-theme-does-not-exist-2",
        "aliabid94/new-theme", "aliabid94/test2", "aliabid94/test3", "aliabid94/test4",
        "abidlabs/banana", "freddyaboulton/test-blue", "gstaff/whiteboard", "ysharma/llamas",
        "abidlabs/font-test", "YenLai/Superhuman", "bethecloud/storj_theme", "sudeepshouche/minimalist",
        "knotdgaf/gradiotest", "ParityError/Anime", "Ajaxon6255/Emerald_Isle", "ParityError/LimeFace",
        "finlaymacklon/smooth_slate", "finlaymacklon/boxy_violet", "derekzen/stardust", "EveryPizza/Cartoony-Gradio-Theme",
        "Ifeanyi/Cyanister", "Tshackelton/IBMPlex-DenseReadable", "snehilsanyal/scikit-learn", "Himhimhim/xkcd",
        "nota-ai/theme", "rawrsor1/Everforest", "rottenlittlecreature/Moon_Goblin", "abidlabs/test-yellow",
        "abidlabs/test-yellow3", "idspicQstitho/dracula_revamped", "kfahn/AnimalPose", "HaleyCH/HaleyCH_Theme",
        "simulKitke/dracula_test", "braintacles/CrimsonNight", "wentaohe/whiteboardv2", "reilnuud/polite",
        "remilia/Ghostly", "Franklisi/darkmode", "coding-alt/soft", "xiaobaiyuan/theme_land",
        "step-3-profit/Midnight-Deep", "xiaobaiyuan/theme_demo", "Taithrah/Minimal", "Insuz/SimpleIndigo",
        "zkunn/Alipay_Gradio_theme", "Insuz/Mocha", "xiaobaiyuan/theme_brief", "Ama434/434-base-Barlow",
        "Ama434/def_barlow", "Ama434/neutral-barlow", "dawood/dracula_test", "nuttea/Softblue",
        "BlueDancer/Alien_Diffusion", "naughtondale/monochrome", "Dagfinn1962/standard", "default"
    ],

    "tts_voice": [
        "af-ZA-AdriNeural", "af-ZA-WillemNeural", "sq-AL-AnilaNeural",
        "sq-AL-IlirNeural", "am-ET-AmehaNeural", "am-ET-MekdesNeural",
        "ar-DZ-AminaNeural", "ar-DZ-IsmaelNeural", "ar-BH-AliNeural",
        "ar-BH-LailaNeural", "ar-EG-SalmaNeural", "ar-EG-ShakirNeural",
        "ar-IQ-BasselNeural", "ar-IQ-RanaNeural", "ar-JO-SanaNeural",
        "ar-JO-TaimNeural", "ar-KW-FahedNeural", "ar-KW-NouraNeural",
        "ar-LB-LaylaNeural", "ar-LB-RamiNeural", "ar-LY-ImanNeural",
        "ar-LY-OmarNeural", "ar-MA-JamalNeural", "ar-MA-MounaNeural",
        "ar-OM-AbdullahNeural", "ar-OM-AyshaNeural", "ar-QA-AmalNeural",
        "ar-QA-MoazNeural", "ar-SA-HamedNeural", "ar-SA-ZariyahNeural",
        "ar-SY-AmanyNeural", "ar-SY-LaithNeural", "ar-TN-HediNeural",
        "ar-TN-ReemNeural", "ar-AE-FatimaNeural", "ar-AE-HamdanNeural",
        "ar-YE-MaryamNeural", "ar-YE-SalehNeural", "az-AZ-BabekNeural",
        "az-AZ-BanuNeural", "bn-BD-NabanitaNeural", "bn-BD-PradeepNeural",
        "bn-IN-BashkarNeural", "bn-IN-TanishaaNeural", "bs-BA-GoranNeural",
        "bs-BA-VesnaNeural", "bg-BG-BorislavNeural", "bg-BG-KalinaNeural",
        "my-MM-NilarNeural", "my-MM-ThihaNeural", "ca-ES-EnricNeural",
        "ca-ES-JoanaNeural", "zh-HK-HiuGaaiNeural", "zh-HK-HiuMaanNeural",
        "zh-HK-WanLungNeural", "zh-CN-XiaoxiaoNeural", "zh-CN-XiaoyiNeural",
        "zh-CN-YunjianNeural", "zh-CN-YunxiNeural", "zh-CN-YunxiaNeural",
        "zh-CN-YunyangNeural", "zh-CN-liaoning-XiaobeiNeural", "zh-TW-HsiaoChenNeural",
        "zh-TW-YunJheNeural", "zh-TW-HsiaoYuNeural", "zh-CN-shaanxi-XiaoniNeural",
        "hr-HR-GabrijelaNeural", "hr-HR-SreckoNeural", "cs-CZ-AntoninNeural",
        "cs-CZ-VlastaNeural", "da-DK-ChristelNeural", "da-DK-JeppeNeural",
        "nl-BE-ArnaudNeural", "nl-BE-DenaNeural", "nl-NL-ColetteNeural",
        "nl-NL-FennaNeural", "nl-NL-MaartenNeural", "en-AU-NatashaNeural",
        "en-AU-WilliamNeural", "en-CA-ClaraNeural", "en-CA-LiamNeural",
        "en-HK-SamNeural", "en-HK-YanNeural", "en-IN-NeerjaExpressiveNeural",
        "en-IN-NeerjaNeural", "en-IN-PrabhatNeural", "en-IE-ConnorNeural",
        "en-IE-EmilyNeural", "en-KE-AsiliaNeural", "en-KE-ChilembaNeural",
        "en-NZ-MitchellNeural", "en-NZ-MollyNeural", "en-NG-AbeoNeural",
        "en-NG-EzinneNeural", "en-PH-JamesNeural", "en-PH-RosaNeural",
        "en-SG-LunaNeural", "en-SG-WayneNeural", "en-ZA-LeahNeural",
        "en-ZA-LukeNeural", "en-TZ-ElimuNeural", "en-TZ-ImaniNeural",
        "en-GB-LibbyNeural", "en-GB-MaisieNeural", "en-GB-RyanNeural",
        "en-GB-SoniaNeural", "en-GB-ThomasNeural", "en-US-AvaMultilingualNeural",
        "en-US-AndrewMultilingualNeural", "en-US-EmmaMultilingualNeural",
        "en-US-BrianMultilingualNeural", "en-US-AvaNeural", "en-US-AndrewNeural",
        "en-US-EmmaNeural", "en-US-BrianNeural", "en-US-AnaNeural", "en-US-AriaNeural",
        "en-US-ChristopherNeural", "en-US-EricNeural", "en-US-GuyNeural",
        "en-US-JennyNeural", "en-US-MichelleNeural", "en-US-RogerNeural",
        "en-US-SteffanNeural", "et-EE-AnuNeural", "et-EE-KertNeural",
        "fil-PH-AngeloNeural", "fil-PH-BlessicaNeural", "fi-FI-HarriNeural",
        "fi-FI-NooraNeural", "fr-BE-CharlineNeural", "fr-BE-GerardNeural",
        "fr-CA-ThierryNeural", "fr-CA-AntoineNeural", "fr-CA-JeanNeural",
        "fr-CA-SylvieNeural", "fr-FR-VivienneMultilingualNeural", "fr-FR-RemyMultilingualNeural",
        "fr-FR-DeniseNeural", "fr-FR-EloiseNeural", "fr-FR-HenriNeural",
        "fr-CH-ArianeNeural", "fr-CH-FabriceNeural", "gl-ES-RoiNeural",
        "gl-ES-SabelaNeural", "ka-GE-EkaNeural", "ka-GE-GiorgiNeural",
        "de-AT-IngridNeural", "de-AT-JonasNeural", "de-DE-SeraphinaMultilingualNeural",
        "de-DE-FlorianMultilingualNeural", "de-DE-AmalaNeural", "de-DE-ConradNeural",
        "de-DE-KatjaNeural", "de-DE-KillianNeural", "de-CH-JanNeural",
        "de-CH-LeniNeural", "el-GR-AthinaNeural", "el-GR-NestorasNeural",
        "gu-IN-DhwaniNeural", "gu-IN-NiranjanNeural", "he-IL-AvriNeural",
        "he-IL-HilaNeural", "hi-IN-MadhurNeural", "hi-IN-SwaraNeural",
        "hu-HU-NoemiNeural", "hu-HU-TamasNeural", "is-IS-GudrunNeural",
        "is-IS-GunnarNeural", "id-ID-ArdiNeural", "id-ID-GadisNeural",
        "ga-IE-ColmNeural", "ga-IE-OrlaNeural", "it-IT-GiuseppeNeural",
        "it-IT-DiegoNeural", "it-IT-ElsaNeural", "it-IT-IsabellaNeural",
        "ja-JP-KeitaNeural", "ja-JP-NanamiNeural", "jv-ID-DimasNeural",
        "jv-ID-SitiNeural", "kn-IN-GaganNeural", "kn-IN-SapnaNeural",
        "kk-KZ-AigulNeural", "kk-KZ-DauletNeural", "km-KH-PisethNeural",
        "km-KH-SreymomNeural", "ko-KR-HyunsuNeural", "ko-KR-InJoonNeural",
        "ko-KR-SunHiNeural", "lo-LA-ChanthavongNeural", "lo-LA-KeomanyNeural",
        "lv-LV-EveritaNeural", "lv-LV-NilsNeural", "lt-LT-LeonasNeural",
        "lt-LT-OnaNeural", "mk-MK-AleksandarNeural", "mk-MK-MarijaNeural",
        "ms-MY-OsmanNeural", "ms-MY-YasminNeural", "ml-IN-MidhunNeural",
        "ml-IN-SobhanaNeural", "mt-MT-GraceNeural", "mt-MT-JosephNeural",
        "mr-IN-AarohiNeural", "mr-IN-ManoharNeural", "mn-MN-BataaNeural",
        "mn-MN-YesuiNeural", "ne-NP-HemkalaNeural", "ne-NP-SagarNeural",
        "nb-NO-FinnNeural", "nb-NO-PernilleNeural", "ps-AF-GulNawazNeural",
        "ps-AF-LatifaNeural", "fa-IR-DilaraNeural", "fa-IR-FaridNeural",
        "pl-PL-MarekNeural", "pl-PL-ZofiaNeural", "pt-BR-ThalitaNeural",
        "pt-BR-AntonioNeural", "pt-BR-FranciscaNeural", "pt-PT-DuarteNeural",
        "pt-PT-RaquelNeural", "ro-RO-AlinaNeural", "ro-RO-EmilNeural",
        "ru-RU-DmitryNeural", "ru-RU-SvetlanaNeural", "sr-RS-NicholasNeural",
        "sr-RS-SophieNeural", "si-LK-SameeraNeural", "si-LK-ThiliniNeural",
        "sk-SK-LukasNeural", "sk-SK-ViktoriaNeural", "sl-SI-PetraNeural",
        "sl-SI-RokNeural", "so-SO-MuuseNeural", "so-SO-UbaxNeural",
        "es-AR-ElenaNeural", "es-AR-TomasNeural", "es-BO-MarceloNeural",
        "es-BO-SofiaNeural", "es-CL-CatalinaNeural", "es-CL-LorenzoNeural",
        "es-ES-XimenaNeural", "es-CO-GonzaloNeural", "es-CO-SalomeNeural",
        "es-CR-JuanNeural", "es-CR-MariaNeural", "es-CU-BelkysNeural",
        "es-CU-ManuelNeural", "es-DO-EmilioNeural", "es-DO-RamonaNeural",
        "es-EC-AndreaNeural", "es-EC-LuisNeural", "es-SV-LorenaNeural",
        "es-SV-RodrigoNeural", "es-GQ-JavierNeural", "es-GQ-TeresaNeural",
        "es-GT-AndresNeural", "es-GT-MartaNeural", "es-HN-CarlosNeural",
        "es-HN-KarlaNeural", "es-MX-DaliaNeural", "es-MX-JorgeNeural",
        "es-NI-FedericoNeural", "es-NI-YolandaNeural", "es-PA-MargaritaNeural",
        "es-PA-RobertoNeural", "es-PY-MarioNeural", "es-PY-TaniaNeural",
        "es-PE-AlexNeural", "es-PE-CamilaNeural", "es-PR-KarinaNeural",
        "es-PR-VictorNeural", "es-ES-AlvaroNeural", "es-ES-ElviraNeural",
        "es-US-AlonsoNeural", "es-US-PalomaNeural", "es-UY-MateoNeural",
        "es-UY-ValentinaNeural", "es-VE-PaolaNeural", "es-VE-SebastianNeural",
        "su-ID-JajangNeural", "su-ID-TutiNeural", "sw-KE-RafikiNeural",
        "sw-KE-ZuriNeural", "sw-TZ-DaudiNeural", "sw-TZ-RehemaNeural",
        "sv-SE-MattiasNeural", "sv-SE-SofieNeural", "ta-IN-PallaviNeural",
        "ta-IN-ValluvarNeural", "ta-MY-KaniNeural", "ta-MY-SuryaNeural",
        "ta-SG-AnbuNeural", "ta-SG-VenbaNeural", "ta-LK-KumarNeural",
        "ta-LK-SaranyaNeural", "te-IN-MohanNeural", "te-IN-ShrutiNeural",
        "th-TH-NiwatNeural", "th-TH-PremwadeeNeural", "tr-TR-AhmetNeural",
        "tr-TR-EmelNeural", "uk-UA-OstapNeural", "uk-UA-PolinaNeural",
        "ur-IN-GulNeural", "ur-IN-SalmanNeural", "ur-PK-AsadNeural",
        "ur-PK-UzmaNeural", "uz-UZ-MadinaNeural", "uz-UZ-SardorNeural",
        "vi-VN-HoaiMyNeural", "vi-VN-NamMinhNeural", "cy-GB-AledNeural",
        "cy-GB-NiaNeural", "zu-ZA-ThandoNeural", "zu-ZA-ThembaNeural"
    ],

    "separator_tab": false,
    "convert_tab": true,
    "tts_tab": true,
    "effects_tab": true,
    "create_dataset_tab": false,
    "training_tab": false,
    "fushion_tab": false,
    "read_tab": false,
    "downloads_tab": true,
    "settings_tab": false,

    "app_port": 7860,
    "tensorboard_port": 6870,
    "num_of_restart": 5,
    "server_name": "0.0.0.0",
    "app_show_error": true,
    "share": false
}
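The keys at the bottom of this file are consumed directly by the app scripts; run_tensorboard.py above, for example, loads the file with json.load and pulls tensorboard_port. A minimal sketch of reading it:

    import json, os

    with open(os.path.join("main", "configs", "config.json"), "r") as f:
        configs = json.load(f)

    print(configs["app_port"], configs["tensorboard_port"])  # 7860 6870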
main/configs/config.py
DELETED
@@ -1,120 +0,0 @@
import os
import json
import torch


version_config_paths = [
    os.path.join("v1", "32000.json"),
    os.path.join("v1", "40000.json"),
    os.path.join("v1", "48000.json"),
    os.path.join("v2", "48000.json"),
    os.path.join("v2", "40000.json"),
    os.path.join("v2", "32000.json"),
]


def singleton(cls):
    instances = {}

    def get_instance(*args, **kwargs):
        if cls not in instances: instances[cls] = cls(*args, **kwargs)

        return instances[cls]
    return get_instance


@singleton
class Config:
    def __init__(self):
        self.device = "cuda:0" if torch.cuda.is_available() else "cpu"
        self.is_half = self.device != "cpu"
        self.gpu_name = (torch.cuda.get_device_name(int(self.device.split(":")[-1])) if self.device.startswith("cuda") else None)
        self.json_config = self.load_config_json()
        self.translations = self.multi_language()
        self.gpu_mem = None
        self.x_pad, self.x_query, self.x_center, self.x_max = self.device_config()

    def load_config_json(self) -> dict:
        configs = {}

        for config_file in version_config_paths:
            config_path = os.path.join("main", "configs", config_file)

            with open(config_path, "r") as f:
                configs[config_file] = json.load(f)

        return configs

    def multi_language(self):
        with open(os.path.join("main", "configs", "config.json"), "r") as f:
            configs = json.load(f)

        lang = configs["language"]

        if len([l for l in os.listdir(os.path.join("assets", "languages")) if l.endswith(".json")]) < 1: raise FileNotFoundError("Không tìm thấy bất cứ gói ngôn ngữ nào(No package languages found)")

        if not lang: lang = "vi-VN"
        if lang not in configs["support_language"]: raise ValueError("Ngôn ngữ không được hỗ trợ(Language not supported)")

        lang_path = os.path.join("assets", "languages", f"{lang}.json")
        if not os.path.exists(lang_path): lang_path = os.path.join("assets", "languages", "vi-VN.json")

        with open(lang_path, encoding="utf-8") as f:
            translations = json.load(f)

        return translations

    def set_precision(self, precision):
        if precision not in ["fp32", "fp16"]: raise ValueError("Loại chính xác không hợp lệ. Phải là 'fp32' hoặc 'fp16'(Invalid precision type. Must be 'fp32' or 'fp16').")

        fp16_run_value = precision == "fp16"

        for config_path in version_config_paths:
            full_config_path = os.path.join("main", "configs", config_path)

            try:
                with open(full_config_path, "r") as f:
                    config = json.load(f)

                config["train"]["fp16_run"] = fp16_run_value

                with open(full_config_path, "w") as f:
                    json.dump(config, f, indent=4)
            except FileNotFoundError:
                print(self.translations["not_found"].format(name=full_config_path))

        return self.translations["set_precision"].format(precision=precision)

    def device_config(self) -> tuple:
        if self.device.startswith("cuda"):
            self.set_cuda_config()
        elif self.has_mps():
            self.device = "mps"
            self.is_half = False
            self.set_precision("fp32")
        else:
            self.device = "cpu"
            self.is_half = False
            self.set_precision("fp32")

        x_pad, x_query, x_center, x_max = ((3, 10, 60, 65) if self.is_half else (1, 6, 38, 41))

        if self.gpu_mem is not None and self.gpu_mem <= 4: x_pad, x_query, x_center, x_max = (1, 5, 30, 32)

        return x_pad, x_query, x_center, x_max

    def set_cuda_config(self):
        i_device = int(self.device.split(":")[-1])
        self.gpu_name = torch.cuda.get_device_name(i_device)
        low_end_gpus = ["16", "P40", "P10", "1060", "1070", "1080"]

        if (any(gpu in self.gpu_name for gpu in low_end_gpus) and "V100" not in self.gpu_name.upper()):
            self.is_half = False
            self.set_precision("fp32")

        self.gpu_mem = torch.cuda.get_device_properties(i_device).total_memory // (1024**3)

    def has_mps(self) -> bool:
        return torch.backends.mps.is_available()
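The @singleton decorator above replaces Config with a factory closure that caches the first instance, so every Config() call across the tree (convert.py, audio_effects.py, run_tensorboard.py) shares one device setup and one loaded language pack. A quick sketch of the observable behavior:

    a = Config()
    b = Config()
    assert a is b                             # same cached instance per process
    assert a.translations is b.translations   # language pack loaded only once, in __init__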
main/configs/v1/32000.json
DELETED
@@ -1,82 +0,0 @@
{
    "train": {
        "log_interval": 200,
        "seed": 1234,
        "epochs": 20000,
        "learning_rate": 0.0001,
        "betas": [0.8, 0.99],
        "eps": 1e-09,
        "batch_size": 4,
        "fp16_run": false,
        "lr_decay": 0.999875,
        "segment_size": 12800,
        "init_lr_ratio": 1,
        "warmup_epochs": 0,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0,
        "sample_rate": 32000,
        "filter_length": 1024,
        "hop_length": 320,
        "win_length": 1024,
        "n_mel_channels": 80,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "text_enc_hidden_dim": 256,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [10, 4, 2, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4, 4],
        "use_spectral_norm": false,
        "gin_channels": 256,
        "spk_embed_dim": 109
    }
}
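In these configs the product of model.upsample_rates must equal data.hop_length, since the generator upsamples one acoustic-feature frame into exactly one hop of waveform samples: 10 * 4 * 2 * 2 * 2 = 320 here, 10 * 10 * 2 * 2 = 400 for the 40 kHz config below, and 10 * 6 * 2 * 2 * 2 = 480 at 48 kHz. A one-line sanity check:

    import math
    assert math.prod([10, 4, 2, 2, 2]) == 320  # upsample_rates -> hop_length (v1/32000.json)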
main/configs/v1/40000.json
DELETED
@@ -1,80 +0,0 @@
{
    "train": {
        "log_interval": 200,
        "seed": 1234,
        "epochs": 20000,
        "learning_rate": 0.0001,
        "betas": [0.8, 0.99],
        "eps": 1e-09,
        "batch_size": 4,
        "fp16_run": false,
        "lr_decay": 0.999875,
        "segment_size": 12800,
        "init_lr_ratio": 1,
        "warmup_epochs": 0,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0,
        "sample_rate": 40000,
        "filter_length": 2048,
        "hop_length": 400,
        "win_length": 2048,
        "n_mel_channels": 125,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "text_enc_hidden_dim": 256,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [10, 10, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4],
        "use_spectral_norm": false,
        "gin_channels": 256,
        "spk_embed_dim": 109
    }
}
main/configs/v1/48000.json
DELETED
@@ -1,82 +0,0 @@
{
    "train": {
        "log_interval": 200,
        "seed": 1234,
        "epochs": 20000,
        "learning_rate": 0.0001,
        "betas": [0.8, 0.99],
        "eps": 1e-09,
        "batch_size": 4,
        "fp16_run": false,
        "lr_decay": 0.999875,
        "segment_size": 11520,
        "init_lr_ratio": 1,
        "warmup_epochs": 0,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0,
        "sample_rate": 48000,
        "filter_length": 2048,
        "hop_length": 480,
        "win_length": 2048,
        "n_mel_channels": 128,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "text_enc_hidden_dim": 256,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [10, 6, 2, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4, 4],
        "use_spectral_norm": false,
        "gin_channels": 256,
        "spk_embed_dim": 109
    }
}
main/configs/v2/32000.json
DELETED
@@ -1,76 +0,0 @@
{
    "train": {
        "log_interval": 200,
        "seed": 1234,
        "learning_rate": 0.0001,
        "betas": [0.8, 0.99],
        "eps": 1e-09,
        "fp16_run": false,
        "lr_decay": 0.999875,
        "segment_size": 12800,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0,
        "sample_rate": 32000,
        "filter_length": 1024,
        "hop_length": 320,
        "win_length": 1024,
        "n_mel_channels": 80,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "text_enc_hidden_dim": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [10, 8, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [20, 16, 4, 4],
        "use_spectral_norm": false,
        "gin_channels": 256,
        "spk_embed_dim": 109
    }
}
main/configs/v2/40000.json
DELETED
@@ -1,76 +0,0 @@
{
    "train": {
        "log_interval": 200,
        "seed": 1234,
        "learning_rate": 0.0001,
        "betas": [0.8, 0.99],
        "eps": 1e-09,
        "fp16_run": false,
        "lr_decay": 0.999875,
        "segment_size": 12800,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0,
        "sample_rate": 40000,
        "filter_length": 2048,
        "hop_length": 400,
        "win_length": 2048,
        "n_mel_channels": 125,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "text_enc_hidden_dim": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [10, 10, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [16, 16, 4, 4],
        "use_spectral_norm": false,
        "gin_channels": 256,
        "spk_embed_dim": 109
    }
}
main/configs/v2/48000.json
DELETED
@@ -1,76 +0,0 @@
{
    "train": {
        "log_interval": 200,
        "seed": 1234,
        "learning_rate": 0.0001,
        "betas": [0.8, 0.99],
        "eps": 1e-09,
        "fp16_run": false,
        "lr_decay": 0.999875,
        "segment_size": 17280,
        "c_mel": 45,
        "c_kl": 1.0
    },
    "data": {
        "max_wav_value": 32768.0,
        "sample_rate": 48000,
        "filter_length": 2048,
        "hop_length": 480,
        "win_length": 2048,
        "n_mel_channels": 128,
        "mel_fmin": 0.0,
        "mel_fmax": null
    },
    "model": {
        "inter_channels": 192,
        "hidden_channels": 192,
        "filter_channels": 768,
        "text_enc_hidden_dim": 768,
        "n_heads": 2,
        "n_layers": 6,
        "kernel_size": 3,
        "p_dropout": 0,
        "resblock": "1",
        "resblock_kernel_sizes": [3, 7, 11],
        "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
        "upsample_rates": [12, 10, 2, 2],
        "upsample_initial_channel": 512,
        "upsample_kernel_sizes": [24, 20, 4, 4],
        "use_spectral_norm": false,
        "gin_channels": 256,
        "spk_embed_dim": 109
    }
}
main/inference/audio_effects.py
DELETED
@@ -1,170 +0,0 @@
import os
import sys
import librosa
import argparse

import numpy as np
import soundfile as sf

from distutils.util import strtobool
from scipy.signal import butter, filtfilt
from pedalboard import Pedalboard, Chorus, Distortion, Reverb, PitchShift, Delay, Limiter, Gain, Bitcrush, Clipping, Compressor, Phaser

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
translations = Config().translations

def parse_arguments() -> tuple:
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, default="./audios/apply_effects.wav")
    parser.add_argument("--resample", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--resample_sr", type=int, default=0)
    parser.add_argument("--chorus", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--chorus_depth", type=float, default=0.5)
    parser.add_argument("--chorus_rate", type=float, default=1.5)
    parser.add_argument("--chorus_mix", type=float, default=0.5)
    parser.add_argument("--chorus_delay", type=int, default=10)
    parser.add_argument("--chorus_feedback", type=float, default=0)
    parser.add_argument("--distortion", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--drive_db", type=int, default=20)
    parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--reverb_room_size", type=float, default=0.5)
    parser.add_argument("--reverb_damping", type=float, default=0.5)
    parser.add_argument("--reverb_wet_level", type=float, default=0.33)
    parser.add_argument("--reverb_dry_level", type=float, default=0.67)
    parser.add_argument("--reverb_width", type=float, default=1)
    parser.add_argument("--reverb_freeze_mode", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--pitchshift", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--pitch_shift", type=int, default=0)
    parser.add_argument("--delay", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--delay_seconds", type=float, default=0.5)
    parser.add_argument("--delay_feedback", type=float, default=0.5)
    parser.add_argument("--delay_mix", type=float, default=0.5)
    parser.add_argument("--compressor", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--compressor_threshold", type=int, default=-20)
    parser.add_argument("--compressor_ratio", type=float, default=4)
    parser.add_argument("--compressor_attack_ms", type=float, default=10)
    parser.add_argument("--compressor_release_ms", type=int, default=200)
    parser.add_argument("--limiter", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--limiter_threshold", type=int, default=0)
    parser.add_argument("--limiter_release", type=int, default=100)
    parser.add_argument("--gain", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--gain_db", type=int, default=0)
    parser.add_argument("--bitcrush", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--bitcrush_bit_depth", type=int, default=16)
    parser.add_argument("--clipping", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--clipping_threshold", type=int, default=-10)
    parser.add_argument("--phaser", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--phaser_rate_hz", type=float, default=0.5)
    parser.add_argument("--phaser_depth", type=float, default=0.5)
    parser.add_argument("--phaser_centre_frequency_hz", type=int, default=1000)
    parser.add_argument("--phaser_feedback", type=float, default=0)
    parser.add_argument("--phaser_mix", type=float, default=0.5)
    parser.add_argument("--treble_bass_boost", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--bass_boost_db", type=int, default=0)
    parser.add_argument("--bass_boost_frequency", type=int, default=100)
    parser.add_argument("--treble_boost_db", type=int, default=0)
    parser.add_argument("--treble_boost_frequency", type=int, default=3000)
    parser.add_argument("--fade_in_out", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--fade_in_duration", type=float, default=2000)
    parser.add_argument("--fade_out_duration", type=float, default=2000)
    parser.add_argument("--export_format", type=str, default="wav")

    args = parser.parse_args()
    return args


def main():
    args = parse_arguments()

    process_audio(input_path=args.input_path, output_path=args.output_path, resample=args.resample, resample_sr=args.resample_sr, chorus_depth=args.chorus_depth, chorus_rate=args.chorus_rate, chorus_mix=args.chorus_mix, chorus_delay=args.chorus_delay, chorus_feedback=args.chorus_feedback, distortion_drive=args.drive_db, reverb_room_size=args.reverb_room_size, reverb_damping=args.reverb_damping, reverb_wet_level=args.reverb_wet_level, reverb_dry_level=args.reverb_dry_level, reverb_width=args.reverb_width, reverb_freeze_mode=args.reverb_freeze_mode, pitch_shift=args.pitch_shift, delay_seconds=args.delay_seconds, delay_feedback=args.delay_feedback, delay_mix=args.delay_mix, compressor_threshold=args.compressor_threshold, compressor_ratio=args.compressor_ratio, compressor_attack_ms=args.compressor_attack_ms, compressor_release_ms=args.compressor_release_ms, limiter_threshold=args.limiter_threshold, limiter_release=args.limiter_release, gain_db=args.gain_db, bitcrush_bit_depth=args.bitcrush_bit_depth, clipping_threshold=args.clipping_threshold, phaser_rate_hz=args.phaser_rate_hz, phaser_depth=args.phaser_depth, phaser_centre_frequency_hz=args.phaser_centre_frequency_hz, phaser_feedback=args.phaser_feedback, phaser_mix=args.phaser_mix, bass_boost_db=args.bass_boost_db, bass_boost_frequency=args.bass_boost_frequency, treble_boost_db=args.treble_boost_db, treble_boost_frequency=args.treble_boost_frequency, fade_in_duration=args.fade_in_duration, fade_out_duration=args.fade_out_duration, export_format=args.export_format, chorus=args.chorus, distortion=args.distortion, reverb=args.reverb, pitchshift=args.pitchshift, delay=args.delay, compressor=args.compressor, limiter=args.limiter, gain=args.gain, bitcrush=args.bitcrush, clipping=args.clipping, phaser=args.phaser, treble_bass_boost=args.treble_bass_boost, fade_in_out=args.fade_in_out)


def process_audio(input_path, output_path, resample, resample_sr, chorus_depth, chorus_rate, chorus_mix, chorus_delay, chorus_feedback, distortion_drive, reverb_room_size, reverb_damping, reverb_wet_level, reverb_dry_level, reverb_width, reverb_freeze_mode, pitch_shift, delay_seconds, delay_feedback, delay_mix, compressor_threshold, compressor_ratio, compressor_attack_ms, compressor_release_ms, limiter_threshold, limiter_release, gain_db, bitcrush_bit_depth, clipping_threshold, phaser_rate_hz, phaser_depth, phaser_centre_frequency_hz, phaser_feedback, phaser_mix, bass_boost_db, bass_boost_frequency, treble_boost_db, treble_boost_frequency, fade_in_duration, fade_out_duration, export_format, chorus, distortion, reverb, pitchshift, delay, compressor, limiter, gain, bitcrush, clipping, phaser, treble_bass_boost, fade_in_out):
    def bass_boost(audio, gain_db, frequency, sample_rate):
        if gain_db >= 1:
            b, a = butter(4, frequency / (0.5 * sample_rate), btype='low')

            return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
        else: return audio

    def treble_boost(audio, gain_db, frequency, sample_rate):
        if gain_db >= 1:
            b, a = butter(4, frequency / (0.5 * sample_rate), btype='high')

            return filtfilt(b, a, audio) * 10 ** (gain_db / 20)
        else: return audio

    def fade_out_effect(audio, sr, duration=3.0):
        length = int(duration * sr)
        end = audio.shape[0]

        if length > end: length = end
        start = end - length

        audio[start:end] = audio[start:end] * np.linspace(1.0, 0.0, length)
        return audio

    def fade_in_effect(audio, sr, duration=3.0):
        length = int(duration * sr)
        start = 0

        if length > audio.shape[0]: length = audio.shape[0]
        end = length

        audio[start:end] = audio[start:end] * np.linspace(0.0, 1.0, length)
        return audio

    if not input_path or not os.path.exists(input_path):
        print(translations["input_not_valid"])
        sys.exit(1)

    if not output_path:
        print(translations["output_not_valid"])
        sys.exit(1)

    if os.path.exists(output_path): os.remove(output_path)

    try:
        audio, sample_rate = sf.read(input_path)
    except Exception as e:
        raise RuntimeError(translations["errors_loading_audio"].format(e=e))

    try:
        board = Pedalboard()

        if chorus: board.append(Chorus(depth=chorus_depth, rate_hz=chorus_rate, mix=chorus_mix, centre_delay_ms=chorus_delay, feedback=chorus_feedback))
        if distortion: board.append(Distortion(drive_db=distortion_drive))
        if reverb: board.append(Reverb(room_size=reverb_room_size, damping=reverb_damping, wet_level=reverb_wet_level, dry_level=reverb_dry_level, width=reverb_width, freeze_mode=1 if reverb_freeze_mode else 0))
        if pitchshift: board.append(PitchShift(semitones=pitch_shift))
        if delay: board.append(Delay(delay_seconds=delay_seconds, feedback=delay_feedback, mix=delay_mix))
        if compressor: board.append(Compressor(threshold_db=compressor_threshold, ratio=compressor_ratio, attack_ms=compressor_attack_ms, release_ms=compressor_release_ms))
        if limiter: board.append(Limiter(threshold_db=limiter_threshold, release_ms=limiter_release))
        if gain: board.append(Gain(gain_db=gain_db))
        if bitcrush: board.append(Bitcrush(bit_depth=bitcrush_bit_depth))
        if clipping: board.append(Clipping(threshold_db=clipping_threshold))
        if phaser: board.append(Phaser(rate_hz=phaser_rate_hz, depth=phaser_depth, centre_frequency_hz=phaser_centre_frequency_hz, feedback=phaser_feedback, mix=phaser_mix))

        processed_audio = board(audio, sample_rate)

        if treble_bass_boost:
            processed_audio = bass_boost(processed_audio, bass_boost_db, bass_boost_frequency, sample_rate)
            processed_audio = treble_boost(processed_audio, treble_boost_db, treble_boost_frequency, sample_rate)

        if fade_in_out:
            processed_audio = fade_in_effect(processed_audio, sample_rate, fade_in_duration)
            processed_audio = fade_out_effect(processed_audio, sample_rate, fade_out_duration)

        if resample_sr != sample_rate and resample_sr > 0 and resample:
            processed_audio = librosa.resample(processed_audio, orig_sr=sample_rate, target_sr=resample_sr)
            sample_rate = resample_sr
    except Exception as e:
        raise RuntimeError(translations["apply_error"].format(e=e))

    sf.write(output_path.replace(".wav", f".{export_format}"), processed_audio, sample_rate, format=export_format)
    return output_path

if __name__ == "__main__": main()
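A command-line sketch for the effects script above (hypothetical file paths; each effect is gated by its own boolean flag, parsed via strtobool so "true"/"false" strings work):

    python main/inference/audio_effects.py \
        --input_path ./audios/input.wav \
        --output_path ./audios/apply_effects.wav \
        --reverb true --reverb_room_size 0.6 \
        --gain true --gain_db 3 \
        --export_format wav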
main/inference/convert.py
DELETED
@@ -1,1060 +0,0 @@
import gc
import re
import os
import sys
import time
import torch
import faiss
import shutil
import codecs
import pyworld
import librosa
import logging
import argparse
import warnings
import traceback
import torchcrepe
import subprocess
import parselmouth
import logging.handlers

import numpy as np
import soundfile as sf
import noisereduce as nr
import torch.nn.functional as F
import torch.multiprocessing as mp

from tqdm import tqdm
from scipy import signal
from torch import Tensor
from scipy.io import wavfile
from audio_upscaler import upscale
from distutils.util import strtobool
from fairseq import checkpoint_utils
from pydub import AudioSegment, silence


now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
from main.library.predictors.FCPE import FCPE
from main.library.predictors.RMVPE import RMVPE
from main.library.algorithm.synthesizers import Synthesizer


warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

logging.getLogger("wget").setLevel(logging.ERROR)
logging.getLogger("torch").setLevel(logging.ERROR)
logging.getLogger("faiss").setLevel(logging.ERROR)
logging.getLogger("httpx").setLevel(logging.ERROR)
logging.getLogger("fairseq").setLevel(logging.ERROR)
logging.getLogger("httpcore").setLevel(logging.ERROR)
logging.getLogger("faiss.loader").setLevel(logging.ERROR)


FILTER_ORDER = 5
CUTOFF_FREQUENCY = 48
SAMPLE_RATE = 16000

bh, ah = signal.butter(N=FILTER_ORDER, Wn=CUTOFF_FREQUENCY, btype="high", fs=SAMPLE_RATE)
input_audio_path2wav = {}

log_file = os.path.join("assets", "logs", "convert.log")

logger = logging.getLogger(__name__)
logger.propagate = False

translations = Config().translations


if logger.hasHandlers(): logger.handlers.clear()
else:
    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(logging.INFO)

    file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

    file_handler.setFormatter(file_formatter)
    file_handler.setLevel(logging.DEBUG)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)


def parse_arguments() -> tuple:
    parser = argparse.ArgumentParser()
    parser.add_argument("--pitch", type=int, default=0)
    parser.add_argument("--filter_radius", type=int, default=3)
    parser.add_argument("--index_rate", type=float, default=0.5)
    parser.add_argument("--volume_envelope", type=float, default=1)
    parser.add_argument("--protect", type=float, default=0.33)
    parser.add_argument("--hop_length", type=int, default=64)
    parser.add_argument("--f0_method", type=str, default="rmvpe")
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, default="./audios/output.wav")
    parser.add_argument("--pth_path", type=str, required=True)
    parser.add_argument("--index_path", type=str, required=True)
    parser.add_argument("--f0_autotune", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--f0_autotune_strength", type=float, default=1)
    parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--clean_strength", type=float, default=0.7)
    parser.add_argument("--export_format", type=str, default="wav")
    parser.add_argument("--embedder_model", type=str, default="contentvec_base")
    parser.add_argument("--upscale_audio", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--resample_sr", type=int, default=0)
    parser.add_argument("--batch_process", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--batch_size", type=int, default=2)
    parser.add_argument("--split_audio", type=lambda x: bool(strtobool(x)), default=False)

    args = parser.parse_args()
    return args


def main():
    args = parse_arguments()
    pitch = args.pitch
    filter_radius = args.filter_radius
    index_rate = args.index_rate
    volume_envelope = args.volume_envelope
    protect = args.protect
    hop_length = args.hop_length
    f0_method = args.f0_method
    input_path = args.input_path
    output_path = args.output_path
    pth_path = args.pth_path
    index_path = args.index_path
    f0_autotune = args.f0_autotune
    f0_autotune_strength = args.f0_autotune_strength
    clean_audio = args.clean_audio
    clean_strength = args.clean_strength
    export_format = args.export_format
    embedder_model = args.embedder_model
    upscale_audio = args.upscale_audio
    resample_sr = args.resample_sr
    batch_process = args.batch_process
    batch_size = args.batch_size
    split_audio = args.split_audio

    logger.debug(f"{translations['pitch']}: {pitch}")
    logger.debug(f"{translations['filter_radius']}: {filter_radius}")
    logger.debug(f"{translations['index_strength']} {index_rate}")
    logger.debug(f"{translations['volume_envelope']}: {volume_envelope}")
    logger.debug(f"{translations['protect']}: {protect}")
    if f0_method == "crepe" or f0_method == "crepe-tiny": logger.debug(f"Hop length: {hop_length}")
    logger.debug(f"{translations['f0_method']}: {f0_method}")
    logger.debug(f"f0_method: {input_path}")
    logger.debug(f"{translations['audio_path']}: {input_path}")
    logger.debug(f"{translations['output_path']}: {output_path.replace('.wav', f'.{export_format}')}")
    logger.debug(f"{translations['model_path']}: {pth_path}")
    logger.debug(f"{translations['indexpath']}: {index_path}")
    logger.debug(f"{translations['autotune']}: {f0_autotune}")
    logger.debug(f"{translations['clear_audio']}: {clean_audio}")
    if clean_audio: logger.debug(f"{translations['clean_strength']}: {clean_strength}")
    logger.debug(f"{translations['export_format']}: {export_format}")
    logger.debug(f"{translations['hubert_model']}: {embedder_model}")
    logger.debug(f"{translations['upscale_audio']}: {upscale_audio}")
    if resample_sr != 0: logger.debug(f"{translations['sample_rate']}: {resample_sr}")
    if split_audio: logger.debug(f"{translations['batch_process']}: {batch_process}")
    if batch_process and split_audio: logger.debug(f"{translations['batch_size']}: {batch_size}")
    logger.debug(f"{translations['split_audio']}: {split_audio}")
    if f0_autotune: logger.debug(f"{translations['autotune_rate_info']}: {f0_autotune_strength}")

    check_rmvpe_fcpe(f0_method)
    check_hubert(embedder_model)

    run_convert_script(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, input_path=input_path, output_path=output_path, pth_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, upscale_audio=upscale_audio, resample_sr=resample_sr, batch_process=batch_process, batch_size=batch_size, split_audio=split_audio)


def check_rmvpe_fcpe(method):
    def download_rmvpe():
        if not os.path.exists(os.path.join("assets", "model", "predictors", "rmvpe.pt")): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + "rmvpe.pt", "-P", os.path.join("assets", "model", "predictors")], check=True)

    def download_fcpe():
        if not os.path.exists(os.path.join("assets", "model", "predictors", "fcpe.pt")): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + "fcpe.pt", "-P", os.path.join("assets", "model", "predictors")], check=True)

    if method == "rmvpe": download_rmvpe()
    elif method == "fcpe": download_fcpe()
    elif "hybrid" in method:
        methods_str = re.search("hybrid\[(.+)\]", method)
        if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]

        for method in methods:
            if method == "rmvpe": download_rmvpe()
            elif method == "fcpe": download_fcpe()


def check_hubert(hubert):
    if hubert == "contentvec_base" or hubert == "hubert_base" or hubert == "japanese_hubert_base" or hubert == "korean_hubert_base" or hubert == "chinese_hubert_base":
        model_path = os.path.join(now_dir, "assets", "model", "embedders", hubert + '.pt')

        if not os.path.exists(model_path): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + f"{hubert}.pt", "-P", os.path.join("assets", "model", "embedders")], check=True)


def load_audio_infer(file, sample_rate):
    try:
        file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
-
|
21 |
-
import numpy as np
|
22 |
-
import soundfile as sf
|
23 |
-
import noisereduce as nr
|
24 |
-
import torch.nn.functional as F
|
25 |
-
import torch.multiprocessing as mp
|
26 |
-
|
27 |
-
from tqdm import tqdm
|
28 |
-
from scipy import signal
|
29 |
-
from torch import Tensor
|
30 |
-
from scipy.io import wavfile
|
31 |
-
from audio_upscaler import upscale
|
32 |
-
from distutils.util import strtobool
|
33 |
-
from fairseq import checkpoint_utils
|
34 |
-
from pydub import AudioSegment, silence
|
35 |
-
|
36 |
-
|
37 |
-
now_dir = os.getcwd()
|
38 |
-
sys.path.append(now_dir)
|
39 |
-
|
40 |
-
from main.configs.config import Config
|
41 |
-
from main.library.predictors.FCPE import FCPE
|
42 |
-
from main.library.predictors.RMVPE import RMVPE
|
43 |
-
from main.library.algorithm.synthesizers import Synthesizer
|
44 |
-
|
45 |
-
|
46 |
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
47 |
-
warnings.filterwarnings("ignore", category=UserWarning)
|
48 |
-
|
49 |
-
logging.getLogger("wget").setLevel(logging.ERROR)
|
50 |
-
logging.getLogger("torch").setLevel(logging.ERROR)
|
51 |
-
logging.getLogger("faiss").setLevel(logging.ERROR)
|
52 |
-
logging.getLogger("httpx").setLevel(logging.ERROR)
|
53 |
-
logging.getLogger("fairseq").setLevel(logging.ERROR)
|
54 |
-
logging.getLogger("httpcore").setLevel(logging.ERROR)
|
55 |
-
logging.getLogger("faiss.loader").setLevel(logging.ERROR)
|
56 |
-
|
57 |
-
|
58 |
-
FILTER_ORDER = 5
|
59 |
-
CUTOFF_FREQUENCY = 48
|
60 |
-
SAMPLE_RATE = 16000
|
61 |
-
|
62 |
-
bh, ah = signal.butter(N=FILTER_ORDER, Wn=CUTOFF_FREQUENCY, btype="high", fs=SAMPLE_RATE)
|
63 |
-
input_audio_path2wav = {}
|
64 |
-
|
65 |
-
log_file = os.path.join("assets", "logs", "convert.log")
|
66 |
-
|
67 |
-
logger = logging.getLogger(__name__)
|
68 |
-
logger.propagate = False
|
69 |
-
|
70 |
-
translations = Config().translations
|
71 |
-
|
72 |
-
|
73 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
74 |
-
else:
|
75 |
-
console_handler = logging.StreamHandler()
|
76 |
-
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
77 |
-
|
78 |
-
console_handler.setFormatter(console_formatter)
|
79 |
-
console_handler.setLevel(logging.INFO)
|
80 |
-
|
81 |
-
file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
82 |
-
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
83 |
-
|
84 |
-
file_handler.setFormatter(file_formatter)
|
85 |
-
file_handler.setLevel(logging.DEBUG)
|
86 |
-
|
87 |
-
logger.addHandler(console_handler)
|
88 |
-
logger.addHandler(file_handler)
|
89 |
-
logger.setLevel(logging.DEBUG)
|
90 |
-
|
91 |
-
|
92 |
-
def parse_arguments() -> tuple:
|
93 |
-
parser = argparse.ArgumentParser()
|
94 |
-
parser.add_argument("--pitch", type=int, default=0)
|
95 |
-
parser.add_argument("--filter_radius", type=int, default=3)
|
96 |
-
parser.add_argument("--index_rate", type=float, default=0.5)
|
97 |
-
parser.add_argument("--volume_envelope", type=float, default=1)
|
98 |
-
parser.add_argument("--protect", type=float, default=0.33)
|
99 |
-
parser.add_argument("--hop_length", type=int, default=64)
|
100 |
-
parser.add_argument( "--f0_method", type=str, default="rmvpe")
|
101 |
-
parser.add_argument("--input_path", type=str, required=True)
|
102 |
-
parser.add_argument("--output_path", type=str, default="./audios/output.wav")
|
103 |
-
parser.add_argument("--pth_path", type=str, required=True)
|
104 |
-
parser.add_argument("--index_path", type=str, required=True)
|
105 |
-
parser.add_argument("--f0_autotune", type=lambda x: bool(strtobool(x)), default=False)
|
106 |
-
parser.add_argument("--f0_autotune_strength", type=float, default=1)
|
107 |
-
parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
|
108 |
-
parser.add_argument("--clean_strength", type=float, default=0.7)
|
109 |
-
parser.add_argument("--export_format", type=str, default="wav")
|
110 |
-
parser.add_argument("--embedder_model", type=str, default="contentvec_base")
|
111 |
-
parser.add_argument("--upscale_audio", type=lambda x: bool(strtobool(x)), default=False)
|
112 |
-
parser.add_argument("--resample_sr", type=int, default=0)
|
113 |
-
parser.add_argument("--batch_process", type=lambda x: bool(strtobool(x)), default=False)
|
114 |
-
parser.add_argument("--batch_size", type=int, default=2)
|
115 |
-
parser.add_argument("--split_audio", type=lambda x: bool(strtobool(x)), default=False)
|
116 |
-
|
117 |
-
args = parser.parse_args()
|
118 |
-
return args
|
119 |
-
|
120 |
-
|
121 |
-
def main():
    args = parse_arguments()
    pitch = args.pitch
    filter_radius = args.filter_radius
    index_rate = args.index_rate
    volume_envelope = args.volume_envelope
    protect = args.protect
    hop_length = args.hop_length
    f0_method = args.f0_method
    input_path = args.input_path
    output_path = args.output_path
    pth_path = args.pth_path
    index_path = args.index_path
    f0_autotune = args.f0_autotune
    f0_autotune_strength = args.f0_autotune_strength
    clean_audio = args.clean_audio
    clean_strength = args.clean_strength
    export_format = args.export_format
    embedder_model = args.embedder_model
    upscale_audio = args.upscale_audio
    resample_sr = args.resample_sr
    batch_process = args.batch_process
    batch_size = args.batch_size
    split_audio = args.split_audio

    logger.debug(f"{translations['pitch']}: {pitch}")
    logger.debug(f"{translations['filter_radius']}: {filter_radius}")
    logger.debug(f"{translations['index_strength']}: {index_rate}")
    logger.debug(f"{translations['volume_envelope']}: {volume_envelope}")
    logger.debug(f"{translations['protect']}: {protect}")
    if f0_method in ("crepe", "crepe-tiny"): logger.debug(f"Hop length: {hop_length}")
    logger.debug(f"{translations['f0_method']}: {f0_method}")
    logger.debug(f"{translations['audio_path']}: {input_path}")
    logger.debug(f"{translations['output_path']}: {output_path.replace('.wav', f'.{export_format}')}")
    logger.debug(f"{translations['model_path']}: {pth_path}")
    logger.debug(f"{translations['indexpath']}: {index_path}")
    logger.debug(f"{translations['autotune']}: {f0_autotune}")
    logger.debug(f"{translations['clear_audio']}: {clean_audio}")
    if clean_audio: logger.debug(f"{translations['clean_strength']}: {clean_strength}")
    logger.debug(f"{translations['export_format']}: {export_format}")
    logger.debug(f"{translations['hubert_model']}: {embedder_model}")
    logger.debug(f"{translations['upscale_audio']}: {upscale_audio}")
    if resample_sr != 0: logger.debug(f"{translations['sample_rate']}: {resample_sr}")
    if split_audio: logger.debug(f"{translations['batch_process']}: {batch_process}")
    if batch_process and split_audio: logger.debug(f"{translations['batch_size']}: {batch_size}")
    logger.debug(f"{translations['split_audio']}: {split_audio}")
    if f0_autotune: logger.debug(f"{translations['autotune_rate_info']}: {f0_autotune_strength}")

    check_rmvpe_fcpe(f0_method)
    check_hubert(embedder_model)

    run_convert_script(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, input_path=input_path, output_path=output_path, pth_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, embedder_model=embedder_model, upscale_audio=upscale_audio, resample_sr=resample_sr, batch_process=batch_process, batch_size=batch_size, split_audio=split_audio)

def check_rmvpe_fcpe(method):
    def download_rmvpe():
        if not os.path.exists(os.path.join("assets", "model", "predictors", "rmvpe.pt")): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + "rmvpe.pt", "-P", os.path.join("assets", "model", "predictors")], check=True)

    def download_fcpe():
        if not os.path.exists(os.path.join("assets", "model", "predictors", "fcpe.pt")): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + "fcpe.pt", "-P", os.path.join("assets", "model", "predictors")], check=True)

    if method == "rmvpe": download_rmvpe()
    elif method == "fcpe": download_fcpe()
    elif "hybrid" in method:
        methods_str = re.search(r"hybrid\[(.+)\]", method)
        if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]

        for method in methods:
            if method == "rmvpe": download_rmvpe()
            elif method == "fcpe": download_fcpe()

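# --- Editor's note (illustrative sketch, not part of the original file) ---
# The download URLs above are stored ROT13-obfuscated and decoded at call
# time with the standard-library codec; a minimal round-trip:
import codecs
assert codecs.decode("uggcf://", "rot13") == "https://"
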
def check_hubert(hubert):
    if hubert in ("contentvec_base", "hubert_base", "japanese_hubert_base", "korean_hubert_base", "chinese_hubert_base"):
        model_path = os.path.join(now_dir, "assets", "model", "embedders", hubert + ".pt")

        if not os.path.exists(model_path): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + f"{hubert}.pt", "-P", os.path.join("assets", "model", "embedders")], check=True)

def load_audio_infer(file, sample_rate):
    try:
        file = file.strip(' "\n')
        if not os.path.isfile(file): raise FileNotFoundError(translations["not_found"].format(name=file))

        audio, sr = sf.read(file)

        if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
        if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate)
    except Exception as e:
        raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")

    return audio.flatten()

def process_audio(file_path, output_path):
    try:
        song = AudioSegment.from_file(file_path)
        nonsilent_parts = silence.detect_nonsilent(song, min_silence_len=750, silence_thresh=-70)

        cut_files = []
        time_stamps = []

        min_chunk_duration = 30

        for i, (start_i, end_i) in enumerate(nonsilent_parts):
            chunk = song[start_i:end_i]

            if len(chunk) >= min_chunk_duration:
                chunk_file_path = os.path.join(output_path, f"chunk{i}.wav")

                if os.path.exists(chunk_file_path): os.remove(chunk_file_path)
                chunk.export(chunk_file_path, format="wav")

                cut_files.append(chunk_file_path)
                time_stamps.append((start_i, end_i))
            else: logger.debug(translations["skip_file"].format(i=i, chunk=len(chunk)))

        logger.info(f"{translations['split_total']}: {len(cut_files)}")
        return cut_files, time_stamps
    except Exception as e:
        raise RuntimeError(f"{translations['process_audio_error']}: {e}")

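# --- Editor's note (illustrative sketch, not part of the original file) ---
# pydub measures slice indices and len() in milliseconds, so the
# min_chunk_duration = 30 threshold above keeps only chunks of >= 30 ms:
from pydub import AudioSegment
one_second = AudioSegment.silent(duration=1000)
assert len(one_second) == 1000  # length reported in ms, not samples
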
def merge_audio(files_list, time_stamps, original_file_path, output_path, format):
    try:
        def extract_number(filename):
            match = re.search(r'_(\d+)', filename)
            return int(match.group(1)) if match else 0

        files_list = sorted(files_list, key=extract_number)
        total_duration = len(AudioSegment.from_file(original_file_path))

        combined = AudioSegment.empty()
        current_position = 0

        for file, (start_i, end_i) in zip(files_list, time_stamps):
            if start_i > current_position:
                silence_duration = start_i - current_position
                combined += AudioSegment.silent(duration=silence_duration)

            combined += AudioSegment.from_file(file)
            current_position = end_i

        if current_position < total_duration: combined += AudioSegment.silent(duration=total_duration - current_position)

        combined.export(output_path, format=format)
        return output_path
    except Exception as e:
        raise RuntimeError(f"{translations['merge_error']}: {e}")

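# --- Editor's note (illustrative sketch, not part of the original file) ---
# merge_audio() re-inserts the silence that process_audio() cut out, so the
# converted chunks land back at their original timestamps. A toy timeline,
# using silent stand-ins for the converted chunks:
from pydub import AudioSegment
stamps = [(500, 1500), (2000, 2500)]                  # nonsilent spans (ms)
total, pos, out = 3000, 0, AudioSegment.empty()
for start, end in stamps:
    out += AudioSegment.silent(duration=start - pos)  # gap before the chunk
    out += AudioSegment.silent(duration=end - start)  # stand-in for the chunk
    pos = end
out += AudioSegment.silent(duration=total - pos)      # trailing gap
assert len(out) == total
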
def run_batch_convert(params):
    cvt = VoiceConverter()

    path = params["path"]
    audio_temp = params["audio_temp"]
    export_format = params["export_format"]
    cut_files = params["cut_files"]
    pitch = params["pitch"]
    filter_radius = params["filter_radius"]
    index_rate = params["index_rate"]
    volume_envelope = params["volume_envelope"]
    protect = params["protect"]
    hop_length = params["hop_length"]
    f0_method = params["f0_method"]
    pth_path = params["pth_path"]
    index_path = params["index_path"]
    f0_autotune = params["f0_autotune"]
    f0_autotune_strength = params["f0_autotune_strength"]
    clean_audio = params["clean_audio"]
    clean_strength = params["clean_strength"]
    upscale_audio = params["upscale_audio"]
    embedder_model = params["embedder_model"]
    resample_sr = params["resample_sr"]

    segment_output_path = os.path.join(audio_temp, f"output_{cut_files.index(path)}.{export_format}")
    if os.path.exists(segment_output_path): os.remove(segment_output_path)

    cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=path, audio_output_path=segment_output_path, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, upscale_audio=upscale_audio, embedder_model=embedder_model, resample_sr=resample_sr)
    os.remove(path)

    if os.path.exists(segment_output_path): return segment_output_path
    else:
        logger.warning(f"{translations['not_found_convert_file']}: {segment_output_path}")
        sys.exit(1)

def run_convert_script(pitch, filter_radius, index_rate, volume_envelope, protect, hop_length, f0_method, input_path, output_path, pth_path, index_path, f0_autotune, f0_autotune_strength, clean_audio, clean_strength, export_format, upscale_audio, embedder_model, resample_sr, batch_process, batch_size, split_audio):
    cvt = VoiceConverter()
    start_time = time.time()

    if not pth_path or not os.path.exists(pth_path) or os.path.isdir(pth_path) or not pth_path.endswith(".pth"):
        logger.warning(translations["provide_file"].format(filename=translations["model"]))
        sys.exit(1)

    if not index_path or not os.path.exists(index_path) or os.path.isdir(index_path) or not index_path.endswith(".index"):
        logger.warning(translations["provide_file"].format(filename=translations["index"]))
        sys.exit(1)

    output_dir = os.path.dirname(output_path)
    output_dir = output_path if not output_dir else output_dir

    if output_dir is None: output_dir = "audios"

    if not os.path.exists(output_dir): os.makedirs(output_dir, exist_ok=True)

    audio_temp = os.path.join("audios_temp")
    if not os.path.exists(audio_temp) and split_audio: os.makedirs(audio_temp, exist_ok=True)

    processed_segments = []

    if os.path.isdir(input_path):
        try:
            logger.info(translations["convert_batch"])

            audio_files = [f for f in os.listdir(input_path) if f.endswith(("wav", "mp3", "flac", "ogg", "opus", "m4a", "mp4", "aac", "alac", "wma", "aiff", "webm", "ac3"))]
            if not audio_files:
                logger.warning(translations["not_found_audio"])
                sys.exit(1)

            logger.info(translations["found_audio"].format(audio_files=len(audio_files)))

            for audio in audio_files:
                audio_path = os.path.join(input_path, audio)
                output_audio = os.path.join(input_path, os.path.splitext(audio)[0] + f"_output.{export_format}")

                if split_audio:
                    try:
                        cut_files, time_stamps = process_audio(audio_path, audio_temp)
                        num_threads = min(batch_size, len(cut_files))

                        params_list = [
                            {
                                "path": path,
                                "audio_temp": audio_temp,
                                "export_format": export_format,
                                "cut_files": cut_files,
                                "pitch": pitch,
                                "filter_radius": filter_radius,
                                "index_rate": index_rate,
                                "volume_envelope": volume_envelope,
                                "protect": protect,
                                "hop_length": hop_length,
                                "f0_method": f0_method,
                                "pth_path": pth_path,
                                "index_path": index_path,
                                "f0_autotune": f0_autotune,
                                "f0_autotune_strength": f0_autotune_strength,
                                "clean_audio": clean_audio,
                                "clean_strength": clean_strength,
                                "upscale_audio": upscale_audio,
                                "embedder_model": embedder_model,
                                "resample_sr": resample_sr
                            }
                            for path in cut_files
                        ]

                        if batch_process:
                            with mp.Pool(processes=num_threads) as pool:
                                with tqdm(total=len(params_list), desc=translations["convert_audio"]) as pbar:
                                    for results in pool.imap_unordered(run_batch_convert, params_list):
                                        processed_segments.append(results)
                                        pbar.update(1)
                        else:
                            for params in tqdm(params_list, desc=translations["convert_audio"]):
                                run_batch_convert(params)

                        merge_audio(processed_segments, time_stamps, audio_path, output_audio, export_format)
                    except Exception as e:
                        logger.error(translations["error_convert_batch"].format(e=e))
                    finally:
                        if os.path.exists(audio_temp): shutil.rmtree(audio_temp, ignore_errors=True)
                else:
                    try:
                        logger.info(f"{translations['convert_audio']} '{audio_path}'...")

                        if os.path.exists(output_audio): os.remove(output_audio)

                        with tqdm(total=1, desc=translations["convert_audio"]) as pbar:
                            cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=audio_path, audio_output_path=output_audio, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, upscale_audio=upscale_audio, embedder_model=embedder_model, resample_sr=resample_sr)
                            pbar.update(1)
                    except Exception as e:
                        logger.error(translations["error_convert"].format(e=e))

            elapsed_time = time.time() - start_time
            logger.info(translations["convert_batch_success"].format(elapsed_time=f"{elapsed_time:.2f}", output_path=output_path.replace('.wav', f'.{export_format}')))
        except Exception as e:
            logger.error(translations["error_convert_batch_2"].format(e=e))
    else:
        logger.info(f"{translations['convert_audio']} '{input_path}'...")

        if not os.path.exists(input_path):
            logger.warning(translations["not_found_audio"])
            sys.exit(1)

        if os.path.isdir(output_path): output_path = os.path.join(output_path, f"output.{export_format}")
        if os.path.exists(output_path): os.remove(output_path)

        if split_audio:
            try:
                cut_files, time_stamps = process_audio(input_path, audio_temp)
                num_threads = min(batch_size, len(cut_files))

                params_list = [
                    {
                        "path": path,
                        "audio_temp": audio_temp,
                        "export_format": export_format,
                        "cut_files": cut_files,
                        "pitch": pitch,
                        "filter_radius": filter_radius,
                        "index_rate": index_rate,
                        "volume_envelope": volume_envelope,
                        "protect": protect,
                        "hop_length": hop_length,
                        "f0_method": f0_method,
                        "pth_path": pth_path,
                        "index_path": index_path,
                        "f0_autotune": f0_autotune,
                        "f0_autotune_strength": f0_autotune_strength,
                        "clean_audio": clean_audio,
                        "clean_strength": clean_strength,
                        "upscale_audio": upscale_audio,
                        "embedder_model": embedder_model,
                        "resample_sr": resample_sr
                    }
                    for path in cut_files
                ]

                if batch_process:
                    with mp.Pool(processes=num_threads) as pool:
                        with tqdm(total=len(params_list), desc=translations["convert_audio"]) as pbar:
                            for results in pool.imap_unordered(run_batch_convert, params_list):
                                processed_segments.append(results)
                                pbar.update(1)
                else:
                    for params in tqdm(params_list, desc=translations["convert_audio"]):
                        run_batch_convert(params)

                merge_audio(processed_segments, time_stamps, input_path, output_path.replace(".wav", f".{export_format}"), export_format)
            except Exception as e:
                logger.error(translations["error_convert_batch"].format(e=e))
            finally:
                if os.path.exists(audio_temp): shutil.rmtree(audio_temp, ignore_errors=True)
        else:
            try:
                with tqdm(total=1, desc=translations["convert_audio"]) as pbar:
                    cvt.convert_audio(pitch=pitch, filter_radius=filter_radius, index_rate=index_rate, volume_envelope=volume_envelope, protect=protect, hop_length=hop_length, f0_method=f0_method, audio_input_path=input_path, audio_output_path=output_path, model_path=pth_path, index_path=index_path, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength, clean_audio=clean_audio, clean_strength=clean_strength, export_format=export_format, upscale_audio=upscale_audio, embedder_model=embedder_model, resample_sr=resample_sr)
                    pbar.update(1)
            except Exception as e:
                logger.error(translations["error_convert"].format(e=e))

        elapsed_time = time.time() - start_time
        logger.info(translations["convert_audio_success"].format(input_path=input_path, elapsed_time=f"{elapsed_time:.2f}", output_path=output_path.replace('.wav', f'.{export_format}')))

def change_rms(source_audio: np.ndarray, source_rate: int, target_audio: np.ndarray, target_rate: int, rate: float) -> np.ndarray:
    rms1 = librosa.feature.rms(y=source_audio, frame_length=source_rate // 2 * 2, hop_length=source_rate // 2)
    rms2 = librosa.feature.rms(y=target_audio, frame_length=target_rate // 2 * 2, hop_length=target_rate // 2)

    rms1 = F.interpolate(torch.from_numpy(rms1).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze()
    rms2 = F.interpolate(torch.from_numpy(rms2).float().unsqueeze(0), size=target_audio.shape[0], mode="linear").squeeze()

    rms2 = torch.maximum(rms2, torch.zeros_like(rms2) + 1e-6)

    adjusted_audio = target_audio * (torch.pow(rms1, 1 - rate) * torch.pow(rms2, rate - 1)).numpy()
    return adjusted_audio

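# --- Editor's note (illustrative sketch, not part of the original file) ---
# change_rms() blends loudness envelopes: the per-sample gain is
# rms1^(1-rate) * rms2^(rate-1), so rate=1 leaves the target envelope
# untouched and rate=0 fully transfers the source envelope (gain rms1/rms2):
import numpy as np
rms1, rms2, rate = 0.5, 0.25, 0.0
gain = rms1 ** (1 - rate) * rms2 ** (rate - 1)
assert np.isclose(gain, rms1 / rms2)  # == 2.0, raising target to source level
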
class Autotune:
    def __init__(self, ref_freqs):
        self.ref_freqs = ref_freqs
        self.note_dict = self.ref_freqs

    def autotune_f0(self, f0, f0_autotune_strength):
        autotuned_f0 = np.zeros_like(f0)

        for i, freq in enumerate(f0):
            closest_note = min(self.note_dict, key=lambda x: abs(x - freq))
            autotuned_f0[i] = freq + (closest_note - freq) * f0_autotune_strength

        return autotuned_f0

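# --- Editor's note (illustrative sketch, not part of the original file) ---
# Each f0 frame is pulled toward the nearest reference pitch by
# f0_autotune_strength; strength 1.0 snaps fully, 0.5 moves halfway:
notes = [415.30, 440.00, 466.16]
freq = 450.0
closest = min(notes, key=lambda x: abs(x - freq))   # -> 440.0
assert freq + (closest - freq) * 1.0 == 440.0       # full snap
assert freq + (closest - freq) * 0.5 == 445.0       # halfway
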
class VC:
    def __init__(self, tgt_sr, config):
        self.x_pad = config.x_pad
        self.x_query = config.x_query
        self.x_center = config.x_center
        self.x_max = config.x_max
        self.is_half = config.is_half
        self.sample_rate = 16000
        self.window = 160
        self.t_pad = self.sample_rate * self.x_pad
        self.t_pad_tgt = tgt_sr * self.x_pad
        self.t_pad2 = self.t_pad * 2
        self.t_query = self.sample_rate * self.x_query
        self.t_center = self.sample_rate * self.x_center
        self.t_max = self.sample_rate * self.x_max
        self.time_step = self.window / self.sample_rate * 1000
        self.f0_min = 50
        self.f0_max = 1100
        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
        self.device = config.device
        # Equal-tempered note frequencies used as autotune snap targets.
        self.ref_freqs = [
            49.00, 51.91, 55.00, 58.27, 61.74, 65.41, 69.30, 73.42, 77.78,
            82.41, 87.31, 92.50, 98.00, 103.83, 110.00, 116.54, 123.47,
            130.81, 138.59, 146.83, 155.56, 164.81, 174.61, 185.00, 196.00,
            207.65, 220.00, 233.08, 246.94, 261.63, 277.18, 293.66, 311.13,
            329.63, 349.23, 369.99, 392.00, 415.30, 440.00, 466.16, 493.88,
            523.25, 554.37, 587.33, 622.25, 659.25, 698.46, 739.99, 783.99,
            830.61, 880.00, 932.33, 987.77, 1046.50
        ]
        self.autotune = Autotune(self.ref_freqs)
        self.note_dict = self.autotune.note_dict

    def get_f0_crepe(self, x, f0_min, f0_max, p_len, hop_length, model="full"):
        x = x.astype(np.float32)
        x /= np.quantile(np.abs(x), 0.999)

        audio = torch.from_numpy(x).to(self.device, copy=True)
        audio = torch.unsqueeze(audio, dim=0)

        if audio.ndim == 2 and audio.shape[0] > 1: audio = torch.mean(audio, dim=0, keepdim=True).detach()

        audio = audio.detach()
        pitch: Tensor = torchcrepe.predict(audio, self.sample_rate, hop_length, f0_min, f0_max, model, batch_size=hop_length * 2, device=self.device, pad=True)

        p_len = p_len or x.shape[0] // hop_length
        source = np.array(pitch.squeeze(0).cpu().float().numpy())
        source[source < 0.001] = np.nan

        target = np.interp(np.arange(0, len(source) * p_len, len(source)) / p_len, np.arange(0, len(source)), source)

        f0 = np.nan_to_num(target)
        return f0

    def get_f0_hybrid(self, methods_str, x, f0_min, f0_max, p_len, hop_length, filter_radius):
        methods_str = re.search(r"hybrid\[(.+)\]", methods_str)
        if methods_str: methods = [method.strip() for method in methods_str.group(1).split("+")]

        f0_computation_stack = []
        logger.debug(translations["hybrid_methods"].format(methods=methods))

        x = x.astype(np.float32)
        x /= np.quantile(np.abs(x), 0.999)

        for method in methods:
            f0 = None

            if method == "pm":
                f0 = parselmouth.Sound(x, self.sample_rate).to_pitch_ac(time_step=self.window / self.sample_rate * 1000 / 1000, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"]
                pad_size = (p_len - len(f0) + 1) // 2

                if pad_size > 0 or p_len - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
            elif method == "dio":
                f0, t = pyworld.dio(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
                f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.sample_rate)
                f0 = signal.medfilt(f0, 3)
            elif method == "crepe-tiny":
                f0 = self.get_f0_crepe(x, self.f0_min, self.f0_max, p_len, int(hop_length), "tiny")
            elif method == "crepe":
                f0 = self.get_f0_crepe(x, f0_min, f0_max, p_len, int(hop_length))
            elif method == "fcpe":
                self.model_fcpe = FCPE(os.path.join("assets", "model", "predictors", "fcpe.pt"), hop_length=int(hop_length), f0_min=int(f0_min), f0_max=int(f0_max), dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.03)
                f0 = self.model_fcpe.compute_f0(x, p_len=p_len)

                del self.model_fcpe
                gc.collect()
            elif method == "rmvpe":
                f0 = RMVPE(os.path.join("assets", "model", "predictors", "rmvpe.pt"), is_half=self.is_half, device=self.device).infer_from_audio(x, thred=0.03)
                f0 = f0[1:]
            elif method == "harvest":
                f0, t = pyworld.harvest(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
                f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.sample_rate)

                if filter_radius > 2: f0 = signal.medfilt(f0, 3)
            else: raise ValueError(translations["method_not_valid"])

            f0_computation_stack.append(f0)

        resampled_stack = []

        for f0 in f0_computation_stack:
            resampled_f0 = np.interp(np.linspace(0, len(f0), p_len), np.arange(len(f0)), f0)
            resampled_stack.append(resampled_f0)

        f0_median_hybrid = resampled_stack[0] if len(resampled_stack) == 1 else np.nanmedian(np.vstack(resampled_stack), axis=0)
        return f0_median_hybrid

    def get_f0(self, input_audio_path, x, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength):
        global input_audio_path2wav

        if f0_method == "pm":
            f0 = parselmouth.Sound(x, self.sample_rate).to_pitch_ac(time_step=self.window / self.sample_rate * 1000 / 1000, voicing_threshold=0.6, pitch_floor=self.f0_min, pitch_ceiling=self.f0_max).selected_array["frequency"]
            pad_size = (p_len - len(f0) + 1) // 2

            if pad_size > 0 or p_len - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, p_len - len(f0) - pad_size]], mode="constant")
        elif f0_method == "dio":
            f0, t = pyworld.dio(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
            f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.sample_rate)
            f0 = signal.medfilt(f0, 3)
        elif f0_method == "crepe-tiny":
            f0 = self.get_f0_crepe(x, self.f0_min, self.f0_max, p_len, int(hop_length), "tiny")
        elif f0_method == "crepe":
            f0 = self.get_f0_crepe(x, self.f0_min, self.f0_max, p_len, int(hop_length))
        elif f0_method == "fcpe":
            self.model_fcpe = FCPE(os.path.join("assets", "model", "predictors", "fcpe.pt"), hop_length=int(hop_length), f0_min=int(self.f0_min), f0_max=int(self.f0_max), dtype=torch.float32, device=self.device, sample_rate=self.sample_rate, threshold=0.03)
            f0 = self.model_fcpe.compute_f0(x, p_len=p_len)

            del self.model_fcpe
            gc.collect()
        elif f0_method == "rmvpe":
            f0 = RMVPE(os.path.join("assets", "model", "predictors", "rmvpe.pt"), is_half=self.is_half, device=self.device).infer_from_audio(x, thred=0.03)
        elif f0_method == "harvest":
            f0, t = pyworld.harvest(x.astype(np.double), fs=self.sample_rate, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=10)
            f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.sample_rate)

            if filter_radius > 2: f0 = signal.medfilt(f0, 3)
        elif "hybrid" in f0_method:
            input_audio_path2wav[input_audio_path] = x.astype(np.double)
            f0 = self.get_f0_hybrid(f0_method, x, self.f0_min, self.f0_max, p_len, hop_length, filter_radius)
        else: raise ValueError(translations["method_not_valid"])

        # Originally written as Autotune.autotune_f0(self, ...), which worked only
        # because VC also carries a note_dict attribute; the instance call is equivalent.
        if f0_autotune: f0 = self.autotune.autotune_f0(f0, f0_autotune_strength)

        f0 *= pow(2, pitch / 12)

        f0bak = f0.copy()

        f0_mel = 1127 * np.log(1 + f0 / 700)
        f0_mel[f0_mel > 0] = (f0_mel[f0_mel > 0] - self.f0_mel_min) * 254 / (self.f0_mel_max - self.f0_mel_min) + 1
        f0_mel[f0_mel <= 1] = 1
        f0_mel[f0_mel > 255] = 255

        f0_coarse = np.rint(f0_mel).astype(np.int32)
        return f0_coarse, f0bak

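# --- Editor's note (illustrative sketch, not part of the original file) ---
# get_f0() quantizes pitch onto a 255-bin mel scale:
#   mel = 1127 * ln(1 + f0/700), then mapped linearly into [1, 255].
import numpy as np
f0_min, f0_max = 50, 1100
mel_min, mel_max = 1127 * np.log(1 + f0_min / 700), 1127 * np.log(1 + f0_max / 700)
mel = 1127 * np.log(1 + 440.0 / 700)                                # A4
coarse = int(np.rint((mel - mel_min) * 254 / (mel_max - mel_min) + 1))
assert 1 <= coarse <= 255
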
    def voice_conversion(self, model, net_g, sid, audio0, pitch, pitchf, index, big_npy, index_rate, version, protect):
        pitch_guidance = pitch is not None and pitchf is not None

        feats = torch.from_numpy(audio0).half() if self.is_half else torch.from_numpy(audio0).float()

        if feats.dim() == 2: feats = feats.mean(-1)
        assert feats.dim() == 1, feats.dim()

        feats = feats.view(1, -1)

        padding_mask = torch.BoolTensor(feats.shape).to(self.device).fill_(False)

        inputs = {
            "source": feats.to(self.device),
            "padding_mask": padding_mask,
            "output_layer": 9 if version == "v1" else 12,
        }

        with torch.no_grad():
            logits = model.extract_features(**inputs)
            feats = model.final_proj(logits[0]) if version == "v1" else logits[0]

        if protect < 0.5 and pitch_guidance: feats0 = feats.clone()

        if index is not None and big_npy is not None and index_rate != 0:
            npy = feats[0].cpu().numpy()

            if self.is_half: npy = npy.astype("float32")

            score, ix = index.search(npy, k=8)

            weight = np.square(1 / score)
            weight /= weight.sum(axis=1, keepdims=True)

            npy = np.sum(big_npy[ix] * np.expand_dims(weight, axis=2), axis=1)

            if self.is_half: npy = npy.astype("float16")

            feats = torch.from_numpy(npy).unsqueeze(0).to(self.device) * index_rate + (1 - index_rate) * feats

        feats = F.interpolate(feats.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)

        if protect < 0.5 and pitch_guidance: feats0 = F.interpolate(feats0.permute(0, 2, 1), scale_factor=2).permute(0, 2, 1)

        p_len = audio0.shape[0] // self.window

        if feats.shape[1] < p_len:
            p_len = feats.shape[1]

            if pitch_guidance:
                pitch = pitch[:, :p_len]
                pitchf = pitchf[:, :p_len]

        if protect < 0.5 and pitch_guidance:
            pitchff = pitchf.clone()
            pitchff[pitchf > 0] = 1
            pitchff[pitchf < 1] = protect
            pitchff = pitchff.unsqueeze(-1)

            feats = feats * pitchff + feats0 * (1 - pitchff)
            feats = feats.to(feats0.dtype)

        p_len = torch.tensor([p_len], device=self.device).long()

        with torch.no_grad():
            audio1 = (net_g.infer(feats, p_len, pitch, pitchf, sid)[0][0, 0]).data.cpu().float().numpy() if pitch_guidance else (net_g.infer(feats, p_len, sid)[0][0, 0]).data.cpu().float().numpy()

        del feats, p_len, padding_mask

        if torch.cuda.is_available(): torch.cuda.empty_cache()
        return audio1

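# --- Editor's note (illustrative sketch, not part of the original file) ---
# The index lookup above blends each embedding frame with the average of its
# k=8 nearest training features, weighted by inverse squared distance:
import numpy as np
score = np.array([[1.0, 2.0, 4.0]])        # toy distances for k=3 neighbours
weight = np.square(1 / score)
weight /= weight.sum(axis=1, keepdims=True)
assert np.isclose(weight.sum(), 1.0)       # the closest neighbour dominates
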
    def pipeline(self, model, net_g, sid, audio, input_audio_path, pitch, f0_method, file_index, index_rate, pitch_guidance, filter_radius, tgt_sr, resample_sr, volume_envelope, version, protect, hop_length, f0_autotune, f0_autotune_strength):
        if file_index != "" and os.path.exists(file_index) and index_rate != 0:
            try:
                index = faiss.read_index(file_index)
                big_npy = index.reconstruct_n(0, index.ntotal)
            except Exception as e:
                logger.error(translations["read_faiss_index_error"].format(e=e))
                index = big_npy = None
        else: index = big_npy = None

        audio = signal.filtfilt(bh, ah, audio)
        audio_pad = np.pad(audio, (self.window // 2, self.window // 2), mode="reflect")
        opt_ts = []

        if audio_pad.shape[0] > self.t_max:
            audio_sum = np.zeros_like(audio)

            for i in range(self.window):
                audio_sum += audio_pad[i : i - self.window]

            for t in range(self.t_center, audio.shape[0], self.t_center):
                opt_ts.append(t - self.t_query + np.where(np.abs(audio_sum[t - self.t_query : t + self.t_query]) == np.abs(audio_sum[t - self.t_query : t + self.t_query]).min())[0][0])

        s = 0
        audio_opt = []
        t = None

        audio_pad = np.pad(audio, (self.t_pad, self.t_pad), mode="reflect")
        p_len = audio_pad.shape[0] // self.window

        sid = torch.tensor(sid, device=self.device).unsqueeze(0).long()

        if pitch_guidance:
            pitch, pitchf = self.get_f0(input_audio_path, audio_pad, p_len, pitch, f0_method, filter_radius, hop_length, f0_autotune, f0_autotune_strength)
            pitch = pitch[:p_len]
            pitchf = pitchf[:p_len]

            if self.device == "mps": pitchf = pitchf.astype(np.float32)

            pitch = torch.tensor(pitch, device=self.device).unsqueeze(0).long()
            pitchf = torch.tensor(pitchf, device=self.device).unsqueeze(0).float()

        for t in opt_ts:
            t = t // self.window * self.window

            if pitch_guidance: audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[s : t + self.t_pad2 + self.window], pitch[:, s // self.window : (t + self.t_pad2) // self.window], pitchf[:, s // self.window : (t + self.t_pad2) // self.window], index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
            else: audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[s : t + self.t_pad2 + self.window], None, None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])

            s = t

        if pitch_guidance: audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[t:], pitch[:, t // self.window :] if t is not None else pitch, pitchf[:, t // self.window :] if t is not None else pitchf, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])
        else: audio_opt.append(self.voice_conversion(model, net_g, sid, audio_pad[t:], None, None, index, big_npy, index_rate, version, protect)[self.t_pad_tgt : -self.t_pad_tgt])

        audio_opt = np.concatenate(audio_opt)

        if volume_envelope != 1: audio_opt = change_rms(audio, self.sample_rate, audio_opt, tgt_sr, volume_envelope)
        if resample_sr >= self.sample_rate and tgt_sr != resample_sr: audio_opt = librosa.resample(audio_opt, orig_sr=tgt_sr, target_sr=resample_sr)

        audio_max = np.abs(audio_opt).max() / 0.99
        max_int16 = 32768

        if audio_max > 1: max_int16 /= audio_max

        audio_opt = (audio_opt * max_int16).astype(np.int16)

        if pitch_guidance: del pitch, pitchf
        del sid

        if torch.cuda.is_available(): torch.cuda.empty_cache()
        return audio_opt

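# --- Editor's note (illustrative sketch, not part of the original file) ---
# The final scaling keeps a 1% headroom before int16 export: peaks above
# 0.99 shrink the 32768 multiplier so the waveform never clips.
import numpy as np
audio_opt = np.array([0.5, -1.5])          # a clipping-prone float signal
audio_max = np.abs(audio_opt).max() / 0.99
max_int16 = 32768
if audio_max > 1: max_int16 /= audio_max
out = (audio_opt * max_int16).astype(np.int16)
assert np.abs(out).max() <= 32768
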
class VoiceConverter:
    def __init__(self):
        self.config = Config()
        self.hubert_model = None

        self.tgt_sr = None
        self.net_g = None

        self.vc = None
        self.cpt = None

        self.version = None
        self.n_spk = None

        self.use_f0 = None
        self.loaded_model = None

    def load_hubert(self, embedder_model):
        try:
            models, _, _ = checkpoint_utils.load_model_ensemble_and_task([os.path.join(now_dir, "assets", "model", "embedders", embedder_model + ".pt")], suffix="")
        except Exception as e:
            raise ImportError(translations["read_model_error"].format(e=e))

        self.hubert_model = models[0].to(self.config.device)
        self.hubert_model = self.hubert_model.half() if self.config.is_half else self.hubert_model.float()
        self.hubert_model.eval()

    @staticmethod
    def remove_audio_noise(input_audio_path, reduction_strength=0.7):
        try:
            rate, data = wavfile.read(input_audio_path)
            reduced_noise = nr.reduce_noise(y=data, sr=rate, prop_decrease=reduction_strength)

            return reduced_noise
        except Exception as e:
            logger.error(translations["denoise_error"].format(e=e))
            return None

    @staticmethod
    def convert_audio_format(input_path, output_path, output_format):
        try:
            if output_format != "wav":
                logger.debug(translations["change_format"].format(output_format=output_format))
                audio, sample_rate = sf.read(input_path)

                common_sample_rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000]

                target_sr = min(common_sample_rates, key=lambda x: abs(x - sample_rate))
                audio = librosa.resample(audio, orig_sr=sample_rate, target_sr=target_sr)

                sf.write(output_path, audio, target_sr, format=output_format)

            return output_path
        except Exception as e:
            raise RuntimeError(translations["change_format_error"].format(e=e))

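# --- Editor's note (illustrative sketch, not part of the original file) ---
# Non-wav exports are first resampled to the nearest "common" rate:
rates = [8000, 11025, 12000, 16000, 22050, 24000, 32000, 44100, 48000]
assert min(rates, key=lambda x: abs(x - 47000)) == 48000
assert min(rates, key=lambda x: abs(x - 21000)) == 22050
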
    def convert_audio(self, audio_input_path, audio_output_path, model_path, index_path, embedder_model, pitch, f0_method, index_rate, volume_envelope, protect, hop_length, f0_autotune, f0_autotune_strength, filter_radius, clean_audio, clean_strength, export_format, upscale_audio, resample_sr=0, sid=0):
        self.get_vc(model_path, sid)

        try:
            if upscale_audio: upscale(audio_input_path, audio_input_path)

            audio = load_audio_infer(audio_input_path, 16000)

            audio_max = np.abs(audio).max() / 0.95

            if audio_max > 1: audio /= audio_max

            if not self.hubert_model:
                if not os.path.exists(os.path.join(now_dir, "assets", "model", "embedders", embedder_model + ".pt")): raise FileNotFoundError(f"Model not found: {embedder_model}")

                self.load_hubert(embedder_model)

            # Chained comparison: true when resample_sr >= 16000 and differs from tgt_sr.
            if self.tgt_sr != resample_sr >= 16000: self.tgt_sr = resample_sr

            file_index = index_path.strip(' "\n').replace("trained", "added")

            audio_opt = self.vc.pipeline(model=self.hubert_model, net_g=self.net_g, sid=sid, audio=audio, input_audio_path=audio_input_path, pitch=pitch, f0_method=f0_method, file_index=file_index, index_rate=index_rate, pitch_guidance=self.use_f0, filter_radius=filter_radius, tgt_sr=self.tgt_sr, resample_sr=resample_sr, volume_envelope=volume_envelope, version=self.version, protect=protect, hop_length=hop_length, f0_autotune=f0_autotune, f0_autotune_strength=f0_autotune_strength)

            if audio_output_path: sf.write(audio_output_path, audio_opt, self.tgt_sr, format="wav")

            if clean_audio:
                cleaned_audio = self.remove_audio_noise(audio_output_path, clean_strength)
                if cleaned_audio is not None: sf.write(audio_output_path, cleaned_audio, self.tgt_sr, format="wav")

            output_path_format = audio_output_path.replace(".wav", f".{export_format}")
            audio_output_path = self.convert_audio_format(audio_output_path, output_path_format, export_format)
        except Exception as e:
            logger.error(translations["error_convert"].format(e=e))
            logger.error(traceback.format_exc())

    def get_vc(self, weight_root, sid):
        if sid == "" or sid == []:
            self.cleanup_model()
            if torch.cuda.is_available(): torch.cuda.empty_cache()

        if not self.loaded_model or self.loaded_model != weight_root:
            self.load_model(weight_root)

            if self.cpt is not None:
                self.setup_network()
                self.setup_vc_instance()

            self.loaded_model = weight_root

    def cleanup_model(self):
        if self.hubert_model is not None:
            del self.net_g, self.n_spk, self.vc, self.hubert_model, self.tgt_sr

            self.hubert_model = self.net_g = self.n_spk = self.vc = self.tgt_sr = None

            if torch.cuda.is_available(): torch.cuda.empty_cache()

        del self.net_g, self.cpt

        if torch.cuda.is_available(): torch.cuda.empty_cache()
        self.cpt = None

    def load_model(self, weight_root):
        self.cpt = torch.load(weight_root, map_location="cpu") if os.path.isfile(weight_root) else None

    def setup_network(self):
        if self.cpt is not None:
            self.tgt_sr = self.cpt["config"][-1]
            self.cpt["config"][-3] = self.cpt["weight"]["emb_g.weight"].shape[0]
            self.use_f0 = self.cpt.get("f0", 1)

            self.version = self.cpt.get("version", "v1")
            self.text_enc_hidden_dim = 768 if self.version == "v2" else 256

            self.net_g = Synthesizer(*self.cpt["config"], use_f0=self.use_f0, text_enc_hidden_dim=self.text_enc_hidden_dim, is_half=self.config.is_half)

            del self.net_g.enc_q

            self.net_g.load_state_dict(self.cpt["weight"], strict=False)
            self.net_g.eval().to(self.config.device)
            self.net_g = self.net_g.half() if self.config.is_half else self.net_g.float()

    def setup_vc_instance(self):
        if self.cpt is not None:
            self.vc = VC(self.tgt_sr, self.config)
            self.n_spk = self.cpt["config"][-3]

if __name__ == "__main__":
    mp.set_start_method("spawn", force=True)
    main()
main/inference/create_dataset.py
DELETED
@@ -1,370 +0,0 @@
import os
import re
import sys
import time
import yt_dlp
import shutil
import librosa
import logging
import argparse
import warnings
import logging.handlers

import soundfile as sf
import noisereduce as nr

from distutils.util import strtobool
from pydub import AudioSegment, silence

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
from main.library.algorithm.separator import Separator

translations = Config().translations

log_file = os.path.join("assets", "logs", "create_dataset.log")
logger = logging.getLogger(__name__)

if logger.hasHandlers(): logger.handlers.clear()
else:
    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(logging.INFO)

    file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

    file_handler.setFormatter(file_formatter)
    file_handler.setLevel(logging.DEBUG)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)

def parse_arguments() -> tuple:
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_audio", type=str, required=True)
    parser.add_argument("--output_dataset", type=str, default="./dataset")
    parser.add_argument("--resample", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--resample_sr", type=int, default=44100)
    parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--clean_strength", type=float, default=0.7)
    parser.add_argument("--separator_music", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--separator_reverb", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--kim_vocal_version", type=int, default=2)
    parser.add_argument("--overlap", type=float, default=0.25)
    parser.add_argument("--segments_size", type=int, default=256)
    parser.add_argument("--mdx_hop_length", type=int, default=1024)
    parser.add_argument("--mdx_batch_size", type=int, default=1)
    parser.add_argument("--denoise_mdx", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--skip", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--skip_start_audios", type=str, default="0")
    parser.add_argument("--skip_end_audios", type=str, default="0")

    args = parser.parse_args()
    return args


dataset_temp = os.path.join("dataset_temp")

def main():
    args = parse_arguments()
    input_audio = args.input_audio
    output_dataset = args.output_dataset
    resample = args.resample
    resample_sr = args.resample_sr
    clean_dataset = args.clean_dataset
    clean_strength = args.clean_strength
    separator_music = args.separator_music
    separator_reverb = args.separator_reverb
    kim_vocal_version = args.kim_vocal_version
    overlap = args.overlap
    segments_size = args.segments_size
    hop_length = args.mdx_hop_length
    batch_size = args.mdx_batch_size
    denoise_mdx = args.denoise_mdx
    skip = args.skip
    skip_start_audios = args.skip_start_audios
    skip_end_audios = args.skip_end_audios

    logger.debug(f"{translations['audio_path']}: {input_audio}")
    logger.debug(f"{translations['output_path']}: {output_dataset}")
    logger.debug(f"{translations['resample']}: {resample}")
    if resample: logger.debug(f"{translations['sample_rate']}: {resample_sr}")
    logger.debug(f"{translations['clear_dataset']}: {clean_dataset}")
    if clean_dataset: logger.debug(f"{translations['clean_strength']}: {clean_strength}")
    logger.debug(f"{translations['separator_music']}: {separator_music}")
    logger.debug(f"{translations['dereveb_audio']}: {separator_reverb}")
    if separator_music: logger.debug(f"{translations['training_version']}: {kim_vocal_version}")
    logger.debug(f"{translations['segments_size']}: {segments_size}")
    logger.debug(f"{translations['overlap']}: {overlap}")
    logger.debug(f"Hop length: {hop_length}")
    logger.debug(f"{translations['batch_size']}: {batch_size}")
    logger.debug(f"{translations['denoise_mdx']}: {denoise_mdx}")
    logger.debug(f"{translations['skip']}: {skip}")
    if skip: logger.debug(f"{translations['skip_start']}: {skip_start_audios}")
    if skip: logger.debug(f"{translations['skip_end']}: {skip_end_audios}")

    if kim_vocal_version not in (1, 2): raise ValueError(translations["version_not_valid"])
    if separator_reverb and not separator_music: raise ValueError(translations["create_dataset_value_not_valid"])

    start_time = time.time()

    try:
        paths = []

        if not os.path.exists(dataset_temp): os.makedirs(dataset_temp, exist_ok=True)

        urls = input_audio.replace(", ", ",").split(",")

        for url in urls:
            path = downloader(url, urls.index(url))
            paths.append(path)

        if skip:
            skip_start_audios = skip_start_audios.replace(", ", ",").split(",")
            skip_end_audios = skip_end_audios.replace(", ", ",").split(",")

            if len(skip_start_audios) < len(paths) or len(skip_end_audios) < len(paths):
                logger.warning(translations["skip<audio"])
                sys.exit(1)
            elif len(skip_start_audios) > len(paths) or len(skip_end_audios) > len(paths):
                logger.warning(translations["skip>audio"])
                sys.exit(1)
            else:
                for audio, skip_start_audio, skip_end_audio in zip(paths, skip_start_audios, skip_end_audios):
                    # The CLI passes seconds as strings; convert before the numeric checks below.
                    skip_start(audio, float(skip_start_audio))
                    skip_end(audio, float(skip_end_audio))

        if separator_music:
            separator_paths = []

            for audio in paths:
                vocals = separator_music_main(audio, dataset_temp, segments_size, overlap, denoise_mdx, kim_vocal_version, hop_length, batch_size)

                if separator_reverb: vocals = separator_reverb_audio(vocals, dataset_temp, segments_size, overlap, denoise_mdx, hop_length, batch_size)
                separator_paths.append(vocals)

            paths = separator_paths

        processed_paths = []

        for audio in paths:
            output = process_audio(audio)
            processed_paths.append(output)

        paths = processed_paths

        for audio_path in paths:
            data, sample_rate = sf.read(audio_path)

            if resample_sr != sample_rate and resample_sr > 0 and resample:
                data = librosa.resample(data, orig_sr=sample_rate, target_sr=resample_sr)
                sample_rate = resample_sr

            if clean_dataset: data = nr.reduce_noise(y=data, prop_decrease=clean_strength)

            sf.write(audio_path, data, sample_rate)
    except Exception as e:
        raise RuntimeError(f"{translations['create_dataset_error']}: {e}")
    finally:
        for audio in paths:
            shutil.move(audio, output_dataset)

        if os.path.exists(dataset_temp): shutil.rmtree(dataset_temp, ignore_errors=True)

    elapsed_time = time.time() - start_time
    logger.info(translations["create_dataset_success"].format(elapsed_time=f"{elapsed_time:.2f}"))

def downloader(url, name):
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")

        ydl_opts = {
            'format': 'bestaudio/best',
            'outtmpl': os.path.join(dataset_temp, f"{name}"),
            'postprocessors': [{
                'key': 'FFmpegExtractAudio',
                'preferredcodec': 'wav',
                'preferredquality': '192',
            }],
            'noplaylist': True,
            'verbose': False,
        }

        logger.info(f"{translations['starting_download']}: {url}...")
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            ydl.extract_info(url)
            logger.info(f"{translations['download_success']}: {url}")

        return os.path.join(dataset_temp, f"{name}" + ".wav")

def skip_start(input_file, seconds):
    data, sr = sf.read(input_file)

    total_duration = len(data) / sr

    if seconds <= 0: logger.warning(translations["=<0"])
    elif seconds >= total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
    else:
        logger.info(f"{translations['skip_start']}: {input_file}...")

        sf.write(input_file, data[int(seconds * sr):], sr)

        logger.info(translations["skip_start_audio"].format(input_file=input_file))


def skip_end(input_file, seconds):
    data, sr = sf.read(input_file)

    total_duration = len(data) / sr

    if seconds <= 0: logger.warning(translations["=<0"])
    elif seconds > total_duration: logger.warning(translations["skip_warning"].format(seconds=seconds, total_duration=f"{total_duration:.2f}"))
    else:
        logger.info(f"{translations['skip_end']}: {input_file}...")

        sf.write(input_file, data[:-int(seconds * sr)], sr)

        logger.info(translations["skip_end_audio"].format(input_file=input_file))

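# --- Editor's note (illustrative sketch, not part of the original file) ---
# Trimming works in sample counts: seconds * sample_rate samples are dropped
# from the head (skip_start) or tail (skip_end) of the array:
import numpy as np
sr = 16000
data = np.zeros(5 * sr)                    # five seconds of audio
assert len(data[int(2 * sr):]) == 3 * sr   # skip_start(2 s) keeps 3 s
assert len(data[:-int(1 * sr)]) == 4 * sr  # skip_end(1 s) keeps 4 s
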
def process_audio(file_path):
    try:
        song = AudioSegment.from_file(file_path)
        nonsilent_parts = silence.detect_nonsilent(song, min_silence_len=750, silence_thresh=-70)

        cut_files = []

        for i, (start_i, end_i) in enumerate(nonsilent_parts):
            chunk = song[start_i:end_i]

            if len(chunk) >= 30:
                chunk_file_path = os.path.join(os.path.dirname(file_path), f"chunk{i}.wav")
                if os.path.exists(chunk_file_path): os.remove(chunk_file_path)

                chunk.export(chunk_file_path, format="wav")

                cut_files.append(chunk_file_path)
            else: logger.warning(translations["skip_file"].format(i=i, chunk=len(chunk)))

        logger.info(f"{translations['split_total']}: {len(cut_files)}")

        def extract_number(filename):
            match = re.search(r'_(\d+)', filename)
            return int(match.group(1)) if match else 0

        cut_files = sorted(cut_files, key=extract_number)

        combined = AudioSegment.empty()

        for file in cut_files:
            combined += AudioSegment.from_file(file)

        output_path = os.path.splitext(file_path)[0] + "_processed" + ".wav"

        logger.info(translations["merge_audio"])

        combined.export(output_path, format="wav")

        return output_path
    except Exception as e:
        raise RuntimeError(f"{translations['process_audio_error']}: {e}")

def separator_music_main(input, output, segments_size, overlap, denoise, version, hop_length, batch_size):
|
292 |
-
if not os.path.exists(input):
|
293 |
-
logger.warning(translations["input_not_valid"])
|
294 |
-
return None
|
295 |
-
|
296 |
-
if not os.path.exists(output):
|
297 |
-
logger.warning(translations["output_not_valid"])
|
298 |
-
return None
|
299 |
-
|
300 |
-
model = f"Kim_Vocal_{version}.onnx"
|
301 |
-
|
302 |
-
logger.info(translations["separator_process"].format(input=input))
|
303 |
-
output_separator = separator_main(audio_file=input, model_filename=model, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise)
|
304 |
-
|
305 |
-
for f in output_separator:
|
306 |
-
path = os.path.join(output, f)
|
307 |
-
|
308 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
309 |
-
|
310 |
-
if '_(Instrumental)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
|
311 |
-
elif '_(Vocals)_' in f:
|
312 |
-
rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
|
313 |
-
os.rename(path, rename_file)
|
314 |
-
|
315 |
-
logger.info(f": {rename_file}")
|
316 |
-
return rename_file
|
317 |
-
|
318 |
-
|
319 |
-
def separator_reverb_audio(input, output, segments_size, overlap, denoise, hop_length, batch_size):
|
320 |
-
reverb_models = "Reverb_HQ_By_FoxJoy.onnx"
|
321 |
-
|
322 |
-
if not os.path.exists(input):
|
323 |
-
logger.warning(translations["input_not_valid"])
|
324 |
-
return None
|
325 |
-
|
326 |
-
if not os.path.exists(output):
|
327 |
-
logger.warning(translations["output_not_valid"])
|
328 |
-
return None
|
329 |
-
|
330 |
-
logger.info(f"{translations['dereverb']}: {input}...")
|
331 |
-
output_dereverb = separator_main(audio_file=input, model_filename=reverb_models, output_format="wav", output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=hop_length, mdx_hop_length=batch_size, mdx_enable_denoise=denoise)
|
332 |
-
|
333 |
-
for f in output_dereverb:
|
334 |
-
path = os.path.join(output, f)
|
335 |
-
|
336 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
337 |
-
|
338 |
-
if '_(Reverb)_' in f: os.rename(path, os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav")
|
339 |
-
elif '_(No Reverb)_' in f:
|
340 |
-
rename_file = os.path.splitext(path)[0].replace("(", "").replace(")", "") + ".wav"
|
341 |
-
os.rename(path, rename_file)
|
342 |
-
|
343 |
-
logger.info(f"{translations['dereverb_success']}: {rename_file}")
|
344 |
-
return rename_file
|
345 |
-
|
346 |
-
|
347 |
-
def separator_main(audio_file=None, model_filename="Kim_Vocal_1.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True):
|
348 |
-
separator = Separator(
|
349 |
-
log_formatter=file_formatter,
|
350 |
-
log_level=logging.INFO,
|
351 |
-
output_dir=output_dir,
|
352 |
-
output_format=output_format,
|
353 |
-
output_bitrate=None,
|
354 |
-
normalization_threshold=0.9,
|
355 |
-
output_single_stem=None,
|
356 |
-
invert_using_spec=False,
|
357 |
-
sample_rate=44100,
|
358 |
-
mdx_params={
|
359 |
-
"hop_length": mdx_hop_length,
|
360 |
-
"segment_size": mdx_segment_size,
|
361 |
-
"overlap": mdx_overlap,
|
362 |
-
"batch_size": mdx_batch_size,
|
363 |
-
"enable_denoise": mdx_enable_denoise,
|
364 |
-
},
|
365 |
-
)
|
366 |
-
|
367 |
-
separator.load_model(model_filename=model_filename)
|
368 |
-
return separator.separate(audio_file)
|
369 |
-
|
370 |
-
if __name__ == "__main__": main()
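Taken together, these helpers form a download → trim → silence-split-and-merge → MDX-separate pipeline. A minimal usage sketch under stated assumptions: the script's globals (`dataset_temp`, `logger`, `translations`) are initialized, the output directory exists, and the URL and folder names below are illustrative placeholders, not taken from the source.

# Hedged sketch: URL and directories are hypothetical.
audio = downloader("https://www.youtube.com/watch?v=XXXXXXXXXXX", "sample")  # -> dataset_temp/sample.wav
skip_start(audio, 5)           # drop the first 5 seconds in place
skip_end(audio, 10)            # drop the last 10 seconds in place
merged = process_audio(audio)  # cut at silence, re-join the non-silent chunks
vocals = separator_music_main(merged, "audios", segments_size=256, overlap=0.25,
                              denoise=True, version=2, hop_length=1024, batch_size=1)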
main/inference/create_index.py
DELETED
@@ -1,120 +0,0 @@
import os
import sys
import faiss
import logging
import argparse
import logging.handlers

import numpy as np

from multiprocessing import cpu_count
from sklearn.cluster import MiniBatchKMeans


now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config

translations = Config().translations


def parse_arguments() -> tuple:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--rvc_version", type=str, default="v2")
    parser.add_argument("--index_algorithm", type=str, default="Auto")

    args = parser.parse_args()
    return args


def main():
    args = parse_arguments()

    exp_dir = os.path.join("assets", "logs", args.model_name)
    version = args.rvc_version
    index_algorithm = args.index_algorithm

    log_file = os.path.join(exp_dir, "create_index.log")
    logger = logging.getLogger(__name__)

    if logger.hasHandlers(): logger.handlers.clear()
    else:
        console_handler = logging.StreamHandler()
        console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

        console_handler.setFormatter(console_formatter)
        console_handler.setLevel(logging.INFO)

        file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
        file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

        file_handler.setFormatter(file_formatter)
        file_handler.setLevel(logging.DEBUG)

        logger.addHandler(console_handler)
        logger.addHandler(file_handler)
        logger.setLevel(logging.DEBUG)

    logger.debug(f"{translations['modelname']}: {args.model_name}")
    logger.debug(f"{translations['model_path']}: {exp_dir}")
    logger.debug(f"{translations['training_version']}: {version}")
    logger.debug(f"{translations['index_algorithm_info']}: {index_algorithm}")

    try:
        feature_dir = os.path.join(exp_dir, f"{version}_extracted")
        model_name = os.path.basename(exp_dir)

        npys = []
        listdir_res = sorted(os.listdir(feature_dir))

        for name in listdir_res:
            file_path = os.path.join(feature_dir, name)
            phone = np.load(file_path)
            npys.append(phone)

        big_npy = np.concatenate(npys, axis=0)
        big_npy_idx = np.arange(big_npy.shape[0])

        np.random.shuffle(big_npy_idx)

        big_npy = big_npy[big_npy_idx]

        if big_npy.shape[0] > 2e5 and (index_algorithm == "Auto" or index_algorithm == "KMeans"): big_npy = (MiniBatchKMeans(n_clusters=10000, verbose=True, batch_size=256 * cpu_count(), compute_labels=False, init="random").fit(big_npy).cluster_centers_)

        np.save(os.path.join(exp_dir, "total_fea.npy"), big_npy)

        n_ivf = min(int(16 * np.sqrt(big_npy.shape[0])), big_npy.shape[0] // 39)

        index_trained = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")

        index_ivf_trained = faiss.extract_index_ivf(index_trained)
        index_ivf_trained.nprobe = 1

        index_trained.train(big_npy)

        faiss.write_index(index_trained, os.path.join(exp_dir, f"trained_IVF{n_ivf}_Flat_nprobe_{index_ivf_trained.nprobe}_{model_name}_{version}.index"))

        index_added = faiss.index_factory(256 if version == "v1" else 768, f"IVF{n_ivf},Flat")
        index_ivf_added = faiss.extract_index_ivf(index_added)

        index_ivf_added.nprobe = 1
        index_added.train(big_npy)

        batch_size_add = 8192

        for i in range(0, big_npy.shape[0], batch_size_add):
            index_added.add(big_npy[i : i + batch_size_add])

        index_filepath_added = os.path.join(exp_dir, f"added_IVF{n_ivf}_Flat_nprobe_{index_ivf_added.nprobe}_{model_name}_{version}.index")

        faiss.write_index(index_added, index_filepath_added)

        logger.info(f"{translations['save_index']} '{index_filepath_added}'")
    except Exception as e:
        logger.error(f"{translations['create_index_error']}: {e}")

if __name__ == "__main__": main()
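The `added_IVF{n_ivf}_Flat_...` file written above is a plain FAISS IVF-Flat index over the extracted feature rows, so it can be loaded back with the standard FAISS API. A minimal query sketch; the file name and random query vector are placeholders, and 768 matches the v2 feature width used above:

import faiss
import numpy as np

index = faiss.read_index("added_IVF256_Flat_nprobe_1_mymodel_v2.index")  # hypothetical file name
index.nprobe = 1                                   # same nprobe the script bakes into the name
query = np.random.rand(1, 768).astype(np.float32)  # one v2 feature frame
distances, ids = index.search(query, 8)            # ids point back into the stored feature rows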
main/inference/extract.py
DELETED
@@ -1,450 +0,0 @@
import os
import gc
import sys
import time
import tqdm
import torch
import shutil
import codecs
import pyworld
import librosa
import logging
import argparse
import warnings
import subprocess
import torchcrepe
import parselmouth
import logging.handlers

import numpy as np
import soundfile as sf
import torch.nn.functional as F

from random import shuffle
from functools import partial
from multiprocessing import Pool
from distutils.util import strtobool
from fairseq import checkpoint_utils
from concurrent.futures import ThreadPoolExecutor, as_completed

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
from main.library.predictors.FCPE import FCPE
from main.library.predictors.RMVPE import RMVPE

logging.getLogger("wget").setLevel(logging.ERROR)

warnings.filterwarnings("ignore", category=FutureWarning)
warnings.filterwarnings("ignore", category=UserWarning)

logger = logging.getLogger(__name__)
logger.propagate = False

config = Config()
translations = config.translations

def parse_arguments() -> tuple:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--rvc_version", type=str, default="v2")
    parser.add_argument("--f0_method", type=str, default="rmvpe")
    parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
    parser.add_argument("--hop_length", type=int, default=128)
    parser.add_argument("--cpu_cores", type=int, default=2)
    parser.add_argument("--gpu", type=str, default="-")
    parser.add_argument("--sample_rate", type=int, required=True)
    parser.add_argument("--embedder_model", type=str, default="contentvec_base")

    args = parser.parse_args()
    return args

def load_audio(file, sample_rate):
    try:
        file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        audio, sr = sf.read(file)

        if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
        if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate)
    except Exception as e:
        raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")

    return audio.flatten()

def check_rmvpe_fcpe(method):
    if method == "rmvpe" and not os.path.exists(os.path.join("assets", "model", "predictors", "rmvpe.pt")): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + "rmvpe.pt", "-P", os.path.join("assets", "model", "predictors")], check=True)
    elif method == "fcpe" and not os.path.exists(os.path.join("assets", "model", "predictors", "fcpe.pt")): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + "fcpe.pt", "-P", os.path.join("assets", "model", "predictors")], check=True)

def check_hubert(hubert):
    if hubert == "contentvec_base" or hubert == "hubert_base" or hubert == "japanese_hubert_base" or hubert == "korean_hubert_base" or hubert == "chinese_hubert_base":
        model_path = os.path.join(now_dir, "assets", "model", "embedders", hubert + '.pt')

        if not os.path.exists(model_path): subprocess.run(["wget", "-q", "--show-progress", "--no-check-certificate", codecs.decode("uggcf://uhttvatsnpr.pb/NauC/Pbyno_EIP_Cebwrpg_2/erfbyir/znva/", "rot13") + f"{hubert}.pt", "-P", os.path.join("assets", "model", "embedders")], check=True)

def generate_config(rvc_version, sample_rate, model_path):
    config_path = os.path.join("main", "configs", rvc_version, f"{sample_rate}.json")
    config_save_path = os.path.join(model_path, "config.json")
    if not os.path.exists(config_save_path): shutil.copy(config_path, config_save_path)


def generate_filelist(pitch_guidance, model_path, rvc_version, sample_rate):
    gt_wavs_dir = os.path.join(model_path, "sliced_audios")
    feature_dir = os.path.join(model_path, f"{rvc_version}_extracted")

    f0_dir, f0nsf_dir = None, None

    if pitch_guidance:
        f0_dir = os.path.join(model_path, "f0")
        f0nsf_dir = os.path.join(model_path, "f0_voiced")

    gt_wavs_files = set(name.split(".")[0] for name in os.listdir(gt_wavs_dir))
    feature_files = set(name.split(".")[0] for name in os.listdir(feature_dir))

    if pitch_guidance:
        f0_files = set(name.split(".")[0] for name in os.listdir(f0_dir))
        f0nsf_files = set(name.split(".")[0] for name in os.listdir(f0nsf_dir))
        names = gt_wavs_files & feature_files & f0_files & f0nsf_files
    else: names = gt_wavs_files & feature_files

    options = []
    mute_base_path = os.path.join(now_dir, "assets", "logs", "mute")

    for name in names:
        if pitch_guidance: options.append(f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|{f0_dir}/{name}.wav.npy|{f0nsf_dir}/{name}.wav.npy|0")
        else: options.append(f"{gt_wavs_dir}/{name}.wav|{feature_dir}/{name}.npy|0")

    mute_audio_path = os.path.join(mute_base_path, "sliced_audios", f"mute{sample_rate}.wav")
    mute_feature_path = os.path.join(mute_base_path, f"{rvc_version}_extracted", "mute.npy")

    for _ in range(2):
        if pitch_guidance:
            mute_f0_path = os.path.join(mute_base_path, "f0", "mute.wav.npy")
            mute_f0nsf_path = os.path.join(mute_base_path, "f0_voiced", "mute.wav.npy")
            options.append(f"{mute_audio_path}|{mute_feature_path}|{mute_f0_path}|{mute_f0nsf_path}|0")
        else: options.append(f"{mute_audio_path}|{mute_feature_path}|0")

    shuffle(options)

    with open(os.path.join(model_path, "filelist.txt"), "w") as f:
        f.write("\n".join(options))

def setup_paths(exp_dir, version = None):
    wav_path = os.path.join(exp_dir, "sliced_audios_16k")
    if version:
        out_path = os.path.join(exp_dir, "v1_extracted" if version == "v1" else "v2_extracted")
        os.makedirs(out_path, exist_ok=True)

        return wav_path, out_path
    else:
        output_root1 = os.path.join(exp_dir, "f0")
        output_root2 = os.path.join(exp_dir, "f0_voiced")

        os.makedirs(output_root1, exist_ok=True)
        os.makedirs(output_root2, exist_ok=True)

        return wav_path, output_root1, output_root2

def read_wave(wav_path, normalize = False):
    wav, sr = sf.read(wav_path)
    assert sr == 16000, translations["sr_not_16000"]

    feats = torch.from_numpy(wav).float()

    if config.is_half: feats = feats.half()
    if feats.dim() == 2: feats = feats.mean(-1)

    feats = feats.view(1, -1)

    if normalize: feats = F.layer_norm(feats, feats.shape)

    return feats

def get_device(gpu_index):
    if gpu_index == "cpu": return "cpu"

    try:
        index = int(gpu_index)
        if index < torch.cuda.device_count(): return f"cuda:{index}"
        else: logger.warning(translations["gpu_not_valid"])
    except ValueError:
        logger.warning(translations["gpu_not_valid"])

    return "cpu"

class FeatureInput:
    def __init__(self, sample_rate=16000, hop_size=160, device="cpu"):
        self.fs = sample_rate
        self.hop = hop_size
        self.f0_bin = 256
        self.f0_max = 1100.0
        self.f0_min = 50.0
        self.f0_mel_min = 1127 * np.log(1 + self.f0_min / 700)
        self.f0_mel_max = 1127 * np.log(1 + self.f0_max / 700)
        self.device = device

    def compute_f0(self, np_arr, f0_method, hop_length):
        if f0_method == "pm": return self.get_pm(np_arr)
        elif f0_method == 'dio': return self.get_dio(np_arr)
        elif f0_method == "crepe": return self.get_crepe(np_arr, int(hop_length))
        elif f0_method == "crepe-tiny": return self.get_crepe(np_arr, int(hop_length), "tiny")
        elif f0_method == "fcpe": return self.get_fcpe(np_arr, int(hop_length))
        elif f0_method == "rmvpe": return self.get_rmvpe(np_arr)
        elif f0_method == "harvest": return self.get_harvest(np_arr)
        else: raise ValueError(translations["method_not_valid"])

    def get_pm(self, x):
        time_step = 160 / 16000 * 1000
        f0 = (parselmouth.Sound(x, self.fs).to_pitch_ac(time_step=time_step / 1000, voicing_threshold=0.6, pitch_floor=50, pitch_ceiling=1100).selected_array["frequency"])
        pad_size = ((x.size // self.hop) - len(f0) + 1) // 2
        if pad_size > 0 or (x.size // self.hop) - len(f0) - pad_size > 0: f0 = np.pad(f0, [[pad_size, (x.size // self.hop) - len(f0) - pad_size]], mode="constant")

        return f0

    def get_dio(self, x):
        f0, t = pyworld.dio(x.astype(np.double), fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
        f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs)

        return f0

    def get_crepe(self, x, hop_length, model="full"):
        audio = torch.from_numpy(x.astype(np.float32)).to(self.device)
        audio /= torch.quantile(torch.abs(audio), 0.999)
        audio = audio.unsqueeze(0)

        pitch = torchcrepe.predict(audio, self.fs, hop_length, self.f0_min, self.f0_max, model=model, batch_size=hop_length * 2, device=self.device, pad=True)

        source = pitch.squeeze(0).cpu().float().numpy()
        source[source < 0.001] = np.nan
        target = np.interp(np.arange(0, len(source) * (x.size // self.hop), len(source)) / (x.size // self.hop), np.arange(0, len(source)), source)

        return np.nan_to_num(target)

    def get_fcpe(self, x, hop_length):
        self.model_fcpe = FCPE(os.path.join("assets", "model", "predictors", "fcpe.pt"), hop_length=int(hop_length), f0_min=self.f0_min, f0_max=self.f0_max, dtype=torch.float32, device=self.device, sample_rate=self.fs, threshold=0.03)
        f0 = self.model_fcpe.compute_f0(x, p_len=(x.size // self.hop))
        del self.model_fcpe
        gc.collect()
        return f0

    def get_rmvpe(self, x):
        self.model_rmvpe = RMVPE(os.path.join("assets", "model", "predictors", "rmvpe.pt"), is_half=False, device=self.device)
        return self.model_rmvpe.infer_from_audio(x, thred=0.03)

    def get_harvest(self, x):
        f0, t = pyworld.harvest(x.astype(np.double), fs=self.fs, f0_ceil=self.f0_max, f0_floor=self.f0_min, frame_period=1000 * self.hop / self.fs)
        f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.fs)
        return f0

    def coarse_f0(self, f0):
        f0_mel = 1127 * np.log(1 + f0 / 700)
        f0_mel = np.clip((f0_mel - self.f0_mel_min) * (self.f0_bin - 2) / (self.f0_mel_max - self.f0_mel_min) + 1, 1, self.f0_bin - 1)
        return np.rint(f0_mel).astype(int)

    def process_file(self, file_info, f0_method, hop_length):
        inp_path, opt_path1, opt_path2, np_arr = file_info

        if os.path.exists(opt_path1 + ".npy") and os.path.exists(opt_path2 + ".npy"): return

        try:
            feature_pit = self.compute_f0(np_arr, f0_method, hop_length)
            np.save(opt_path2, feature_pit, allow_pickle=False)
            coarse_pit = self.coarse_f0(feature_pit)
            np.save(opt_path1, coarse_pit, allow_pickle=False)
        except Exception as e:
            raise RuntimeError(f"{translations['extract_file_error']} {inp_path}: {e}")

    def process_files(self, files, f0_method, hop_length, pbar):
        for file_info in files:
            self.process_file(file_info, f0_method, hop_length)
            pbar.update()

def run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus):
    input_root, *output_roots = setup_paths(exp_dir)

    if len(output_roots) == 2: output_root1, output_root2 = output_roots
    else:
        output_root1 = output_roots[0]
        output_root2 = None

    paths = [
        (
            os.path.join(input_root, name),
            os.path.join(output_root1, name) if output_root1 else None,
            os.path.join(output_root2, name) if output_root2 else None,
            load_audio(os.path.join(input_root, name), 16000),
        )
        for name in sorted(os.listdir(input_root))
        if "spec" not in name
    ]

    logger.info(translations["extract_f0_method"].format(num_processes=num_processes, f0_method=f0_method))
    start_time = time.time()

    if gpus != "-":
        gpus = gpus.split("-")
        num_gpus = len(gpus)

        process_partials = []

        pbar = tqdm.tqdm(total=len(paths), desc=translations["extract_f0"])

        for idx, gpu in enumerate(gpus):
            device = get_device(gpu)
            feature_input = FeatureInput(device=device)

            part_paths = paths[idx::num_gpus]
            process_partials.append((feature_input, part_paths))

        with ThreadPoolExecutor() as executor:
            futures = [executor.submit(FeatureInput.process_files, feature_input, part_paths, f0_method, hop_length, pbar) for feature_input, part_paths in process_partials]

            for future in as_completed(futures):
                pbar.update(1)
                future.result()

        pbar.close()
    else:
        feature_input = FeatureInput(device="cpu")

        with tqdm.tqdm(total=len(paths), desc=translations["extract_f0"]) as pbar:
            with Pool(processes=num_processes) as pool:
                process_file_partial = partial(feature_input.process_file, f0_method=f0_method, hop_length=hop_length)

                for _ in pool.imap_unordered(process_file_partial, paths):
                    pbar.update(1)

    elapsed_time = time.time() - start_time
    logger.info(translations["extract_f0_success"].format(elapsed_time=f"{elapsed_time:.2f}"))

def process_file_embedding(file, wav_path, out_path, model, device, version, saved_cfg):
    wav_file_path = os.path.join(wav_path, file)
    out_file_path = os.path.join(out_path, file.replace("wav", "npy"))

    if os.path.exists(out_file_path): return

    feats = read_wave(wav_file_path, normalize=saved_cfg.task.normalize)
    dtype = torch.float16 if device.startswith("cuda") else torch.float32
    feats = feats.to(dtype).to(device)

    padding_mask = torch.BoolTensor(feats.shape).fill_(False).to(dtype).to(device)

    inputs = {
        "source": feats,
        "padding_mask": padding_mask,
        "output_layer": 9 if version == "v1" else 12,
    }

    with torch.no_grad():
        model = model.to(device).to(dtype)
        logits = model.extract_features(**inputs)
        feats = model.final_proj(logits[0]) if version == "v1" else logits[0]

    feats = feats.squeeze(0).float().cpu().numpy()

    if not np.isnan(feats).any(): np.save(out_file_path, feats, allow_pickle=False)
    else: logger.warning(f"{file} {translations['NaN']}")

def run_embedding_extraction(exp_dir, version, gpus, embedder_model):
    wav_path, out_path = setup_paths(exp_dir, version)

    logger.info(translations["start_extract_hubert"])
    start_time = time.time()

    try:
        models, saved_cfg, _ = checkpoint_utils.load_model_ensemble_and_task([os.path.join(now_dir, "assets", "model", "embedders", embedder_model + '.pt')], suffix="")
    except Exception as e:
        raise ImportError(translations["read_model_error"].format(e=e))

    model = models[0]
    devices = [get_device(gpu) for gpu in (gpus.split("-") if gpus != "-" else ["cpu"])]

    paths = sorted([file for file in os.listdir(wav_path) if file.endswith(".wav")])

    if not paths:
        logger.warning(translations["not_found_audio_file"])
        sys.exit(1)

    pbar = tqdm.tqdm(total=len(paths) * len(devices), desc=translations["extract_hubert"])

    tasks = [(file, wav_path, out_path, model, device, version, saved_cfg) for file in paths for device in devices]

    for task in tasks:
        try:
            process_file_embedding(*task)
        except Exception as e:
            raise RuntimeError(f"{translations['process_error']} {task[0]}: {e}")

        pbar.update(1)

    pbar.close()

    elapsed_time = time.time() - start_time
    logger.info(translations["extract_hubert_success"].format(elapsed_time=f"{elapsed_time:.2f}"))

if __name__ == "__main__":
    args = parse_arguments()

    exp_dir = os.path.join("assets", "logs", args.model_name)
    f0_method = args.f0_method
    hop_length = args.hop_length
    num_processes = args.cpu_cores
    gpus = args.gpu
    version = args.rvc_version
    pitch_guidance = args.pitch_guidance
    sample_rate = args.sample_rate
    embedder_model = args.embedder_model

    check_rmvpe_fcpe(f0_method)
    check_hubert(embedder_model)

    if len([f for f in os.listdir(os.path.join(exp_dir, "sliced_audios")) if os.path.isfile(os.path.join(exp_dir, "sliced_audios", f))]) < 1 or len([f for f in os.listdir(os.path.join(exp_dir, "sliced_audios_16k")) if os.path.isfile(os.path.join(exp_dir, "sliced_audios_16k", f))]) < 1: raise FileNotFoundError("Processed data not found, please reprocess the audio")

    log_file = os.path.join(exp_dir, "extract.log")

    if logger.hasHandlers(): logger.handlers.clear()
    else:
        console_handler = logging.StreamHandler()
        console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

        console_handler.setFormatter(console_formatter)
        console_handler.setLevel(logging.INFO)

        file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
        file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

        file_handler.setFormatter(file_formatter)
        file_handler.setLevel(logging.DEBUG)

        logger.addHandler(console_handler)
        logger.addHandler(file_handler)
        logger.setLevel(logging.DEBUG)

    logger.debug(f"{translations['modelname']}: {args.model_name}")
    logger.debug(f"{translations['export_process']}: {exp_dir}")
    logger.debug(f"{translations['f0_method']}: {f0_method}")
    logger.debug(f"{translations['pretrain_sr']}: {sample_rate}")
    logger.debug(f"{translations['cpu_core']}: {num_processes}")
    logger.debug(f"Gpu: {gpus}")
    if f0_method == "crepe" or f0_method == "crepe-tiny" or f0_method == "fcpe": logger.debug(f"Hop length: {hop_length}")
    logger.debug(f"{translations['training_version']}: {version}")
    logger.debug(f"{translations['extract_f0']}: {pitch_guidance}")
    logger.debug(f"{translations['hubert_model']}: {embedder_model}")

    try:
        run_pitch_extraction(exp_dir, f0_method, hop_length, num_processes, gpus)
        run_embedding_extraction(exp_dir, version, gpus, embedder_model)

        generate_config(version, sample_rate, exp_dir)
        generate_filelist(pitch_guidance, exp_dir, version, sample_rate)
    except Exception as e:
        logger.error(f"{translations['extract_error']}: {e}")

    logger.info(f"{translations['extract_success']} {args.model_name}.")
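For reference, `coarse_f0` above compresses Hz into 255 mel-spaced bins, with bin 1 doubling as "unvoiced" after the clip. A standalone check of that formula, using only the constants defined in `FeatureInput` (outputs below are computed from the formula itself, not additional project data):

import numpy as np

f0_min, f0_max, f0_bin = 50.0, 1100.0, 256
f0_mel_min = 1127 * np.log(1 + f0_min / 700)  # ~77.8 mel
f0_mel_max = 1127 * np.log(1 + f0_max / 700)  # ~1064.4 mel

def coarse(f0_hz):
    f0_mel = 1127 * np.log(1 + np.asarray(f0_hz) / 700)
    f0_mel = np.clip((f0_mel - f0_mel_min) * (f0_bin - 2) / (f0_mel_max - f0_mel_min) + 1, 1, f0_bin - 1)
    return np.rint(f0_mel).astype(int)

print(coarse([0.0, 50.0, 440.0, 1100.0]))  # -> [  1   1 122 255]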
main/inference/preprocess.py
DELETED
@@ -1,360 +0,0 @@
import os
import sys
import time
import torch
import logging
import librosa
import argparse
import logging.handlers

import numpy as np
import soundfile as sf
import multiprocessing
import noisereduce as nr

from tqdm import tqdm
from scipy import signal
from scipy.io import wavfile
from distutils.util import strtobool
from concurrent.futures import ProcessPoolExecutor, as_completed

now_directory = os.getcwd()
sys.path.append(now_directory)

from main.configs.config import Config

logger = logging.getLogger(__name__)

logging.getLogger("numba.core.byteflow").setLevel(logging.ERROR)
logging.getLogger("numba.core.ssa").setLevel(logging.ERROR)
logging.getLogger("numba.core.interpreter").setLevel(logging.ERROR)

OVERLAP = 0.3
MAX_AMPLITUDE = 0.9
ALPHA = 0.75
HIGH_PASS_CUTOFF = 48
SAMPLE_RATE_16K = 16000


config = Config()
per = 3.0 if config.is_half else 3.7
translations = config.translations


def parse_arguments() -> tuple:
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name", type=str, required=True)
    parser.add_argument("--dataset_path", type=str, default="./dataset")
    parser.add_argument("--sample_rate", type=int, required=True)
    parser.add_argument("--cpu_cores", type=int, default=2)
    parser.add_argument("--cut_preprocess", type=lambda x: bool(strtobool(x)), default=True)
    parser.add_argument("--process_effects", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--clean_dataset", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--clean_strength", type=float, default=0.7)

    args = parser.parse_args()
    return args


def load_audio(file, sample_rate):
    try:
        file = file.strip(" ").strip('"').strip("\n").strip('"').strip(" ")
        audio, sr = sf.read(file)

        if len(audio.shape) > 1: audio = librosa.to_mono(audio.T)
        if sr != sample_rate: audio = librosa.resample(audio, orig_sr=sr, target_sr=sample_rate)
    except Exception as e:
        raise RuntimeError(f"{translations['errors_loading_audio']}: {e}")

    return audio.flatten()

class Slicer:
    def __init__(self, sr, threshold = -40.0, min_length = 5000, min_interval = 300, hop_size = 20, max_sil_kept = 5000):
        if not min_length >= min_interval >= hop_size: raise ValueError(translations["min_length>=min_interval>=hop_size"])
        if not max_sil_kept >= hop_size: raise ValueError(translations["max_sil_kept>=hop_size"])

        min_interval = sr * min_interval / 1000
        self.threshold = 10 ** (threshold / 20.0)
        self.hop_size = round(sr * hop_size / 1000)
        self.win_size = min(round(min_interval), 4 * self.hop_size)
        self.min_length = round(sr * min_length / 1000 / self.hop_size)
        self.min_interval = round(min_interval / self.hop_size)
        self.max_sil_kept = round(sr * max_sil_kept / 1000 / self.hop_size)

    def _apply_slice(self, waveform, begin, end):
        start_idx = begin * self.hop_size

        if len(waveform.shape) > 1:
            end_idx = min(waveform.shape[1], end * self.hop_size)
            return waveform[:, start_idx:end_idx]
        else:
            end_idx = min(waveform.shape[0], end * self.hop_size)
            return waveform[start_idx:end_idx]

    def slice(self, waveform):
        samples = waveform.mean(axis=0) if len(waveform.shape) > 1 else waveform
        if samples.shape[0] <= self.min_length: return [waveform]

        rms_list = get_rms(y=samples, frame_length=self.win_size, hop_length=self.hop_size).squeeze(0)

        sil_tags = []
        silence_start, clip_start = None, 0

        for i, rms in enumerate(rms_list):
            if rms < self.threshold:
                if silence_start is None: silence_start = i
                continue

            if silence_start is None: continue

            is_leading_silence = silence_start == 0 and i > self.max_sil_kept
            need_slice_middle = (i - silence_start >= self.min_interval and i - clip_start >= self.min_length)

            if not is_leading_silence and not need_slice_middle:
                silence_start = None
                continue

            if i - silence_start <= self.max_sil_kept:
                pos = rms_list[silence_start : i + 1].argmin() + silence_start
                if silence_start == 0: sil_tags.append((0, pos))
                else: sil_tags.append((pos, pos))

                clip_start = pos
            elif i - silence_start <= self.max_sil_kept * 2:
                pos = rms_list[i - self.max_sil_kept : silence_start + self.max_sil_kept + 1].argmin()

                pos += i - self.max_sil_kept

                pos_l = (rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start)
                pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)

                if silence_start == 0:
                    sil_tags.append((0, pos_r))
                    clip_start = pos_r
                else:
                    sil_tags.append((min(pos_l, pos), max(pos_r, pos)))
                    clip_start = max(pos_r, pos)
            else:
                pos_l = (rms_list[silence_start : silence_start + self.max_sil_kept + 1].argmin() + silence_start)
                pos_r = (rms_list[i - self.max_sil_kept : i + 1].argmin() + i - self.max_sil_kept)

                if silence_start == 0: sil_tags.append((0, pos_r))
                else: sil_tags.append((pos_l, pos_r))

                clip_start = pos_r
            silence_start = None

        total_frames = rms_list.shape[0]

        if (silence_start is not None and total_frames - silence_start >= self.min_interval):
            silence_end = min(total_frames, silence_start + self.max_sil_kept)
            pos = rms_list[silence_start : silence_end + 1].argmin() + silence_start
            sil_tags.append((pos, total_frames + 1))

        if not sil_tags: return [waveform]
        else:
            chunks = []

            if sil_tags[0][0] > 0: chunks.append(self._apply_slice(waveform, 0, sil_tags[0][0]))

            for i in range(len(sil_tags) - 1):
                chunks.append(self._apply_slice(waveform, sil_tags[i][1], sil_tags[i + 1][0]))

            if sil_tags[-1][1] < total_frames: chunks.append(self._apply_slice(waveform, sil_tags[-1][1], total_frames))

            return chunks


def get_rms(y, frame_length=2048, hop_length=512, pad_mode="constant"):
    padding = (int(frame_length // 2), int(frame_length // 2))
    y = np.pad(y, padding, mode=pad_mode)

    axis = -1
    out_strides = y.strides + tuple([y.strides[axis]])
    x_shape_trimmed = list(y.shape)
    x_shape_trimmed[axis] -= frame_length - 1
    out_shape = tuple(x_shape_trimmed) + tuple([frame_length])
    xw = np.lib.stride_tricks.as_strided(y, shape=out_shape, strides=out_strides)

    target_axis = axis - 1 if axis < 0 else axis + 1

    xw = np.moveaxis(xw, -1, target_axis)
    slices = [slice(None)] * xw.ndim
    slices[axis] = slice(0, None, hop_length)
    x = xw[tuple(slices)]

    power = np.mean(np.abs(x) ** 2, axis=-2, keepdims=True)
    return np.sqrt(power)


class PreProcess:
    def __init__(self, sr, exp_dir, per):
        self.slicer = Slicer(sr=sr, threshold=-42, min_length=1500, min_interval=400, hop_size=15, max_sil_kept=500)
        self.sr = sr
        self.b_high, self.a_high = signal.butter(N=5, Wn=HIGH_PASS_CUTOFF, btype="high", fs=self.sr)
        self.per = per
        self.exp_dir = exp_dir
        self.device = "cpu"
        self.gt_wavs_dir = os.path.join(exp_dir, "sliced_audios")
        self.wavs16k_dir = os.path.join(exp_dir, "sliced_audios_16k")

        os.makedirs(self.gt_wavs_dir, exist_ok=True)
        os.makedirs(self.wavs16k_dir, exist_ok=True)

    def _normalize_audio(self, audio: np.ndarray):
        tmp_max = np.abs(audio).max()
        if tmp_max > 2.5: return None

        return (audio / tmp_max * (MAX_AMPLITUDE * ALPHA)) + (1 - ALPHA) * audio

    def process_audio_segment(self, normalized_audio: np.ndarray, sid, idx0, idx1):
        if normalized_audio is None:
            logger.debug(f"{sid}-{idx0}-{idx1}-filtered")
            return

        wavfile.write(os.path.join(self.gt_wavs_dir, f"{sid}_{idx0}_{idx1}.wav"), self.sr, normalized_audio.astype(np.float32))
        audio_16k = librosa.resample(normalized_audio, orig_sr=self.sr, target_sr=SAMPLE_RATE_16K)
        wavfile.write(os.path.join(self.wavs16k_dir, f"{sid}_{idx0}_{idx1}.wav"), SAMPLE_RATE_16K, audio_16k.astype(np.float32))

    def process_audio(self, path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength):
        try:
            audio = load_audio(path, self.sr)

            if process_effects:
                audio = signal.lfilter(self.b_high, self.a_high, audio)
                audio = self._normalize_audio(audio)

            if clean_dataset: audio = nr.reduce_noise(y=audio, sr=self.sr, prop_decrease=clean_strength)

            idx1 = 0

            if cut_preprocess:
                for audio_segment in self.slicer.slice(audio):
                    i = 0

                    while 1:
                        start = int(self.sr * (self.per - OVERLAP) * i)
                        i += 1

                        if len(audio_segment[start:]) > (self.per + OVERLAP) * self.sr:
                            tmp_audio = audio_segment[start : start + int(self.per * self.sr)]
                            self.process_audio_segment(tmp_audio, sid, idx0, idx1)
                            idx1 += 1
                        else:
                            tmp_audio = audio_segment[start:]
                            self.process_audio_segment(tmp_audio, sid, idx0, idx1)
                            idx1 += 1
                            break
            else: self.process_audio_segment(audio, sid, idx0, idx1)
        except Exception as e:
            raise RuntimeError(f"{translations['process_audio_error']}: {e}")

def process_file(args):
    pp, file, cut_preprocess, process_effects, clean_dataset, clean_strength = (args)
    file_path, idx0, sid = file

    pp.process_audio(file_path, idx0, sid, cut_preprocess, process_effects, clean_dataset, clean_strength)

def preprocess_training_set(input_root, sr, num_processes, exp_dir, per, cut_preprocess, process_effects, clean_dataset, clean_strength):
    start_time = time.time()

    pp = PreProcess(sr, exp_dir, per)
    logger.info(translations["start_preprocess"].format(num_processes=num_processes))

    files = []
    idx = 0

    for root, _, filenames in os.walk(input_root):
        try:
            sid = 0 if root == input_root else int(os.path.basename(root))
            for f in filenames:
                if f.lower().endswith((".wav", ".mp3", ".flac", ".ogg")):
                    files.append((os.path.join(root, f), idx, sid))
                    idx += 1
        except ValueError:
            raise ValueError(f"{translations['not_integer']} '{os.path.basename(root)}'.")

    with tqdm(total=len(files), desc=translations["preprocess"]) as pbar:
        with ProcessPoolExecutor(max_workers=num_processes) as executor:
            futures = [
                executor.submit(
                    process_file,
                    (
                        pp,
                        file,
                        cut_preprocess,
                        process_effects,
                        clean_dataset,
                        clean_strength,
                    ),
                )
                for file in files
            ]
            for future in as_completed(futures):
                try:
                    future.result()
                except Exception as e:
                    raise RuntimeError(f"{translations['process_error']}: {e}")

                pbar.update(1)

    elapsed_time = time.time() - start_time
    logger.info(translations["preprocess_success"].format(elapsed_time=f"{elapsed_time:.2f}"))

if __name__ == "__main__":
    args = parse_arguments()

    experiment_directory = os.path.join("assets", "logs", args.model_name)
    num_processes = args.cpu_cores
    num_processes = multiprocessing.cpu_count() if num_processes is None else int(num_processes)
    dataset = args.dataset_path
    sample_rate = args.sample_rate
    cut_preprocess = args.cut_preprocess
    preprocess_effects = args.process_effects
    clean_dataset = args.clean_dataset
    clean_strength = args.clean_strength

    os.makedirs(experiment_directory, exist_ok=True)

    if len([f for f in os.listdir(os.path.join(dataset)) if os.path.isfile(os.path.join(dataset, f)) and f.lower().endswith((".wav", ".mp3", ".flac", ".ogg"))]) < 1: raise FileNotFoundError("No data found")

    log_file = os.path.join(experiment_directory, "preprocess.log")

    if logger.hasHandlers(): logger.handlers.clear()
    else:
        console_handler = logging.StreamHandler()
        console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

        console_handler.setFormatter(console_formatter)
        console_handler.setLevel(logging.INFO)

        file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
        file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

        file_handler.setFormatter(file_formatter)
        file_handler.setLevel(logging.DEBUG)

        logger.addHandler(console_handler)
        logger.addHandler(file_handler)
        logger.setLevel(logging.DEBUG)

    logger.debug(f"{translations['modelname']}: {args.model_name}")
    logger.debug(f"{translations['export_process']}: {experiment_directory}")
    logger.debug(f"{translations['dataset_folder']}: {dataset}")
    logger.debug(f"{translations['pretrain_sr']}: {sample_rate}")
    logger.debug(f"{translations['cpu_core']}: {num_processes}")
    logger.debug(f"{translations['split_audio']}: {cut_preprocess}")
    logger.debug(f"{translations['preprocess_effect']}: {preprocess_effects}")
    logger.debug(f"{translations['clear_audio']}: {clean_dataset}")
    if clean_dataset: logger.debug(f"{translations['clean_strength']}: {clean_strength}")

    try:
        preprocess_training_set(dataset, sample_rate, num_processes, experiment_directory, per, cut_preprocess, preprocess_effects, clean_dataset, clean_strength)
    except Exception as e:
        logger.error(f"{translations['process_audio_error']} {e}")

    logger.info(f"{translations['preprocess_model_success']} {args.model_name}")
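`get_rms` above is a strided reimplementation of frame-wise RMS with centered (zero-padded) framing. A quick sanity check, assuming only NumPy and the function as defined above; the test tone is illustrative:

import numpy as np

sr = 16000
t = np.arange(sr) / sr
y = 0.5 * np.sin(2 * np.pi * 440 * t)   # RMS of a 0.5-amplitude sine is 0.5/sqrt(2)

rms = get_rms(y, frame_length=2048, hop_length=512)
print(rms.shape)         # (1, n_frames)
print(float(rms.max()))  # ~0.3536 away from the zero-padded edges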
main/inference/separator_music.py
DELETED
@@ -1,400 +0,0 @@
import os
import sys
import time
import logging
import argparse
import logging.handlers

import soundfile as sf
import noisereduce as nr

from pydub import AudioSegment
from distutils.util import strtobool

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
from main.library.algorithm.separator import Separator

translations = Config().translations

log_file = os.path.join("assets", "logs", "separator.log")
logger = logging.getLogger(__name__)

if logger.hasHandlers(): logger.handlers.clear()
else:
    console_handler = logging.StreamHandler()
    console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

    console_handler.setFormatter(console_formatter)
    console_handler.setLevel(logging.INFO)

    file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
    file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")

    file_handler.setFormatter(file_formatter)
    file_handler.setLevel(logging.DEBUG)

    logger.addHandler(console_handler)
    logger.addHandler(file_handler)
    logger.setLevel(logging.DEBUG)

demucs_models = {
    "HT-Tuned": "htdemucs_ft.yaml",
    "HT-Normal": "htdemucs.yaml",
    "HD_MMI": "hdemucs_mmi.yaml",
    "HT_6S": "htdemucs_6s.yaml"
}

mdx_models = {
    "Main_340": "UVR-MDX-NET_Main_340.onnx",
    "Main_390": "UVR-MDX-NET_Main_390.onnx",
    "Main_406": "UVR-MDX-NET_Main_406.onnx",
    "Main_427": "UVR-MDX-NET_Main_427.onnx",
    "Main_438": "UVR-MDX-NET_Main_438.onnx",
    "Inst_full_292": "UVR-MDX-NET-Inst_full_292.onnx",
    "Inst_HQ_1": "UVR-MDX-NET_Inst_HQ_1.onnx",
    "Inst_HQ_2": "UVR-MDX-NET_Inst_HQ_2.onnx",
    "Inst_HQ_3": "UVR-MDX-NET_Inst_HQ_3.onnx",
    "Inst_HQ_4": "UVR-MDX-NET-Inst_HQ_4.onnx",
    "Kim_Vocal_1": "Kim_Vocal_1.onnx",
    "Kim_Vocal_2": "Kim_Vocal_2.onnx",
    "Kim_Inst": "Kim_Inst.onnx",
    "Inst_187_beta": "UVR-MDX-NET_Inst_187_beta.onnx",
    "Inst_82_beta": "UVR-MDX-NET_Inst_82_beta.onnx",
    "Inst_90_beta": "UVR-MDX-NET_Inst_90_beta.onnx",
    "Voc_FT": "UVR-MDX-NET-Voc_FT.onnx",
    "Crowd_HQ": "UVR-MDX-NET_Crowd_HQ_1.onnx",
    "MDXNET_9482": "UVR_MDXNET_9482.onnx",
    "Inst_1": "UVR-MDX-NET-Inst_1.onnx",
    "Inst_2": "UVR-MDX-NET-Inst_2.onnx",
    "Inst_3": "UVR-MDX-NET-Inst_3.onnx",
    "MDXNET_1_9703": "UVR_MDXNET_1_9703.onnx",
    "MDXNET_2_9682": "UVR_MDXNET_2_9682.onnx",
    "MDXNET_3_9662": "UVR_MDXNET_3_9662.onnx",
    "Inst_Main": "UVR-MDX-NET-Inst_Main.onnx",
    "MDXNET_Main": "UVR_MDXNET_Main.onnx"
}

kara_models = {
    "Version-1": "UVR_MDXNET_KARA.onnx",
    "Version-2": "UVR_MDXNET_KARA_2.onnx"
}


def parse_arguments() -> tuple:
    parser = argparse.ArgumentParser()
    parser.add_argument("--input_path", type=str, required=True)
    parser.add_argument("--output_path", type=str, default="./audios")
    parser.add_argument("--format", type=str, default="wav")
    parser.add_argument("--shifts", type=int, default=10)
    parser.add_argument("--segments_size", type=int, default=256)
    parser.add_argument("--overlap", type=float, default=0.25)
    parser.add_argument("--mdx_hop_length", type=int, default=1024)
    parser.add_argument("--mdx_batch_size", type=int, default=1)
    parser.add_argument("--clean_audio", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--clean_strength", type=float, default=0.7)
    parser.add_argument("--backing_denoise", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--demucs_model", type=str, default="HT-Normal")
    parser.add_argument("--kara_model", type=str, default="Version-1")
    parser.add_argument("--mdx_model", type=str, default="Main_340")
    parser.add_argument("--backing", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--mdx", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--mdx_denoise", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--reverb", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--reverb_denoise", type=lambda x: bool(strtobool(x)), default=False)
    parser.add_argument("--backing_reverb", type=lambda x: bool(strtobool(x)), default=False)

    args = parser.parse_args()
    return args

def main():
    start_time = time.time()

    try:
        args = parse_arguments()

        input_path = args.input_path
        output_path = args.output_path
        export_format = args.format
        shifts = args.shifts
        segments_size = args.segments_size
        overlap = args.overlap
        hop_length = args.mdx_hop_length
        batch_size = args.mdx_batch_size
        clean_audio = args.clean_audio
        clean_strength = args.clean_strength
        backing_denoise = args.backing_denoise
        demucs_model = args.demucs_model
        kara_model = args.kara_model
        backing = args.backing
        mdx = args.mdx
        mdx_model = args.mdx_model
        mdx_denoise = args.mdx_denoise
        reverb = args.reverb
        reverb_denoise = args.reverb_denoise
        backing_reverb = args.backing_reverb

        if backing_reverb and not reverb:
            logger.warning(translations["turn_on_dereverb"])
            sys.exit(1)

        if backing_reverb and not backing:
            logger.warning(translations["turn_on_separator_backing"])
            sys.exit(1)

        logger.debug(f"{translations['audio_path']}: {input_path}")
        logger.debug(f"{translations['output_path']}: {output_path}")
        logger.debug(f"{translations['export_format']}: {export_format}")
        if not mdx: logger.debug(f"{translations['shift']}: {shifts}")
        logger.debug(f"{translations['segments_size']}: {segments_size}")
        logger.debug(f"{translations['overlap']}: {overlap}")
        if clean_audio: logger.debug(f"{translations['clear_audio']}: {clean_audio}")
        if clean_audio: logger.debug(f"{translations['clean_strength']}: {clean_strength}")
        if not mdx: logger.debug(f"{translations['demucs_model']}: {demucs_model}")
        if backing: logger.debug(f"{translations['denoise_backing']}: {backing_denoise}")
        if backing: logger.debug(f"{translations['backing_model_ver']}: {kara_model}")
        if backing: logger.debug(f"{translations['separator_backing']}: {backing}")
        if mdx: logger.debug(f"{translations['use_mdx']}: {mdx}")
        if mdx: logger.debug(f"{translations['mdx_model']}: {mdx_model}")
        if mdx: logger.debug(f"{translations['denoise_mdx']}: {mdx_denoise}")
        if mdx or backing or reverb: logger.debug(f"Hop length: {hop_length}")
        if mdx or backing or reverb: logger.debug(f"{translations['batch_size']}: {batch_size}")
        if reverb: logger.debug(f"{translations['dereveb_audio']}: {reverb}")
        if reverb: logger.debug(f"{translations['denoise_dereveb']}: {reverb_denoise}")
        if reverb: logger.debug(f"{translations['dereveb_backing']}: {backing_reverb}")

        if not mdx: vocals, instruments = separator_music_demucs(input_path, output_path, export_format, shifts, overlap, segments_size, demucs_model)
        else: vocals, instruments = separator_music_mdx(input_path, output_path, export_format, segments_size, overlap, mdx_denoise, mdx_model, hop_length, batch_size)
|
170 |
-
|
171 |
-
if backing: main_vocals, backing_vocals = separator_backing(vocals, output_path, export_format, segments_size, overlap, backing_denoise, kara_model, hop_length, batch_size)
|
172 |
-
if reverb: vocals_no_reverb, main_vocals_no_reverb, backing_vocals_no_reverb = separator_reverb(output_path, export_format, segments_size, overlap, reverb_denoise, reverb, backing, backing_reverb, hop_length, batch_size)
|
173 |
-
|
174 |
-
original_output = os.path.join(output_path, f"Original_Vocals_No_Reverb.{export_format}") if reverb else os.path.join(output_path, f"Original_Vocals.{export_format}")
|
175 |
-
main_output = os.path.join(output_path, f"Main_Vocals_No_Reverb.{export_format}") if reverb and backing else os.path.join(output_path, f"Main_Vocals.{export_format}")
|
176 |
-
backing_output = os.path.join(output_path, f"Backing_Vocals_No_Reverb.{export_format}") if reverb and backing_reverb else os.path.join(output_path, f"Backing_Vocals.{export_format}")
|
177 |
-
|
178 |
-
if clean_audio:
|
179 |
-
logger.info(f"{translations['clear_audio']}...")
|
180 |
-
vocal_data, vocal_sr = sf.read(vocals_no_reverb if reverb else vocals)
|
181 |
-
main_data, main_sr = sf.read(main_vocals_no_reverb if reverb and backing else main_vocals)
|
182 |
-
backing_data, backing_sr = sf.read(backing_vocals_no_reverb if reverb and backing_reverb else backing_vocals)
|
183 |
-
|
184 |
-
vocals_clean = nr.reduce_noise(y=vocal_data, prop_decrease=clean_strength)
|
185 |
-
|
186 |
-
sf.write(original_output, vocals_clean, vocal_sr, format=export_format)
|
187 |
-
|
188 |
-
if backing:
|
189 |
-
mains_clean = nr.reduce_noise(y=main_data, sr=main_sr, prop_decrease=clean_strength)
|
190 |
-
backing_clean = nr.reduce_noise(y=backing_data, sr=backing_sr, prop_decrease=clean_strength)
|
191 |
-
sf.write(main_output, mains_clean, main_sr, format=export_format)
|
192 |
-
sf.write(backing_output, backing_clean, backing_sr, format=export_format)
|
193 |
-
|
194 |
-
logger.info(translations["clean_audio_success"])
|
195 |
-
return original_output, instruments, main_output, backing_output
|
196 |
-
except Exception as e:
|
197 |
-
logger.error(f"{translations['separator_error']}: {e}")
|
198 |
-
return None, None, None, None
|
199 |
-
|
200 |
-
elapsed_time = time.time() - start_time
|
201 |
-
logger.info(translations["separator_success"].format(elapsed_time=f"{elapsed_time:.2f}"))
|
202 |
-
|
203 |
-
|
204 |
-
def separator_music_demucs(input, output, format, shifts, overlap, segments_size, demucs_model):
|
205 |
-
if not os.path.exists(input):
|
206 |
-
logger.warning(translations["input_not_valid"])
|
207 |
-
sys.exit(1)
|
208 |
-
|
209 |
-
if not os.path.exists(output):
|
210 |
-
logger.warning(translations["output_not_valid"])
|
211 |
-
sys.exit(1)
|
212 |
-
|
213 |
-
for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
|
214 |
-
if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
|
215 |
-
|
216 |
-
model = demucs_models.get(demucs_model)
|
217 |
-
|
218 |
-
segment_size = segments_size / 2
|
219 |
-
|
220 |
-
logger.info(f"{translations['separator_process_2']}...")
|
221 |
-
|
222 |
-
demucs_output = separator_main(audio_file=input, model_filename=model, output_format=format, output_dir=output, demucs_segment_size=segment_size, demucs_shifts=shifts, demucs_overlap=overlap)
|
223 |
-
|
224 |
-
for f in demucs_output:
|
225 |
-
path = os.path.join(output, f)
|
226 |
-
|
227 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
228 |
-
|
229 |
-
if '_(Drums)_' in f: drums = path
|
230 |
-
elif '_(Bass)_' in f: bass = path
|
231 |
-
elif '_(Other)_' in f: other = path
|
232 |
-
elif '_(Vocals)_' in f: os.rename(path, os.path.join(output, f"Original_Vocals.{format}"))
|
233 |
-
|
234 |
-
AudioSegment.from_file(drums).overlay(AudioSegment.from_file(bass)).overlay(AudioSegment.from_file(other)).export(os.path.join(output, f"Instruments.{format}"), format=format)
|
235 |
-
|
236 |
-
for f in [drums, bass, other]:
|
237 |
-
if os.path.exists(f): os.remove(f)
|
238 |
-
|
239 |
-
logger.info(translations["separator_success_2"])
|
240 |
-
return os.path.join(output, f"Original_Vocals.{format}"), os.path.join(output, f"Instruments.{format}")
|
241 |
-
|
242 |
-
def separator_backing(input, output, format, segments_size, overlap, denoise, kara_model, hop_length, batch_size):
|
243 |
-
if not os.path.exists(input):
|
244 |
-
logger.warning(translations["input_not_valid"])
|
245 |
-
sys.exit(1)
|
246 |
-
|
247 |
-
if not os.path.exists(output):
|
248 |
-
logger.warning(translations["output_not_valid"])
|
249 |
-
sys.exit(1)
|
250 |
-
|
251 |
-
for f in [f"Main_Vocals.{format}", f"Backing_Vocals.{format}"]:
|
252 |
-
if os.path.exists(os.path.join(output, f)): os.remove(os.path.join(output, f))
|
253 |
-
|
254 |
-
model_2 = kara_models.get(kara_model)
|
255 |
-
|
256 |
-
logger.info(f"{translations['separator_process_backing']}...")
|
257 |
-
backing_outputs = separator_main(audio_file=input, model_filename=model_2, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise)
|
258 |
-
|
259 |
-
main_output = os.path.join(output, f"Main_Vocals.{format}")
|
260 |
-
backing_output = os.path.join(output, f"Backing_Vocals.{format}")
|
261 |
-
|
262 |
-
for f in backing_outputs:
|
263 |
-
path = os.path.join(output, f)
|
264 |
-
|
265 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
266 |
-
|
267 |
-
if '_(Instrumental)_' in f: os.rename(path, backing_output)
|
268 |
-
elif '_(Vocals)_' in f: os.rename(path, main_output)
|
269 |
-
|
270 |
-
logger.info(translations["separator_process_backing_success"])
|
271 |
-
return main_output, backing_output
|
272 |
-
|
273 |
-
def separator_music_mdx(input, output, format, segments_size, overlap, denoise, mdx_model, hop_length, batch_size):
|
274 |
-
if not os.path.exists(input):
|
275 |
-
logger.warning(translations["input_not_valid"])
|
276 |
-
sys.exit(1)
|
277 |
-
|
278 |
-
if not os.path.exists(output):
|
279 |
-
logger.warning(translations["output_not_valid"])
|
280 |
-
sys.exit(1)
|
281 |
-
|
282 |
-
for i in [f"Original_Vocals.{format}", f"Instruments.{format}"]:
|
283 |
-
if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
|
284 |
-
|
285 |
-
model_3 = mdx_models.get(mdx_model)
|
286 |
-
|
287 |
-
logger.info(f"{translations['separator_process_2']}...")
|
288 |
-
output_music = separator_main(audio_file=input, model_filename=model_3, output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise)
|
289 |
-
|
290 |
-
original_output = os.path.join(output, f"Original_Vocals.{format}")
|
291 |
-
instruments_output = os.path.join(output, f"Instruments.{format}")
|
292 |
-
|
293 |
-
for f in output_music:
|
294 |
-
path = os.path.join(output, f)
|
295 |
-
|
296 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
297 |
-
|
298 |
-
if '_(Instrumental)_' in f: os.rename(path, instruments_output)
|
299 |
-
elif '_(Vocals)_' in f: os.rename(path, original_output)
|
300 |
-
|
301 |
-
logger.info(translations["separator_process_backing_success"])
|
302 |
-
return original_output, instruments_output
|
303 |
-
|
304 |
-
def separator_reverb(output, format, segments_size, overlap, denoise, original, main, backing_reverb, hop_length, batch_size):
|
305 |
-
if not os.path.exists(output):
|
306 |
-
logger.warning(translations["output_not_valid"])
|
307 |
-
sys.exit(1)
|
308 |
-
|
309 |
-
for i in [f"Original_Vocals_Reverb.{format}", f"Main_Vocals_Reverb.{format}", f"Original_Vocals_No_Reverb.{format}", f"Main_Vocals_No_Reverb.{format}"]:
|
310 |
-
if os.path.exists(os.path.join(output, i)): os.remove(os.path.join(output, i))
|
311 |
-
|
312 |
-
dereveb_path = []
|
313 |
-
|
314 |
-
if original:
|
315 |
-
try:
|
316 |
-
dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Original_Vocals' in f][0]))
|
317 |
-
except IndexError:
|
318 |
-
logger.warning(translations["not_found_original_vocal"])
|
319 |
-
sys.exit(1)
|
320 |
-
|
321 |
-
if main:
|
322 |
-
try:
|
323 |
-
dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Main_Vocals' in f][0]))
|
324 |
-
except IndexError:
|
325 |
-
logger.warning(translations["not_found_main_vocal"])
|
326 |
-
sys.exit(1)
|
327 |
-
|
328 |
-
if backing_reverb:
|
329 |
-
try:
|
330 |
-
dereveb_path.append(os.path.join(output, [f for f in os.listdir(output) if 'Backing_Vocals' in f][0]))
|
331 |
-
except IndexError:
|
332 |
-
logger.warning(translations["not_found_backing_vocal"])
|
333 |
-
sys.exit(1)
|
334 |
-
|
335 |
-
for path in dereveb_path:
|
336 |
-
if not os.path.exists(path):
|
337 |
-
logger.warning(translations["not_found"].format(name=path))
|
338 |
-
sys.exit(1)
|
339 |
-
|
340 |
-
if "Original_Vocals" in path:
|
341 |
-
reverb_path = os.path.join(output, f"Original_Vocals_Reverb.{format}")
|
342 |
-
no_reverb_path = os.path.join(output, f"Original_Vocals_No_Reverb.{format}")
|
343 |
-
start_title = translations["process_original"]
|
344 |
-
end_title = translations["process_original_success"]
|
345 |
-
elif "Main_Vocals" in path:
|
346 |
-
reverb_path = os.path.join(output, f"Main_Vocals_Reverb.{format}")
|
347 |
-
no_reverb_path = os.path.join(output, f"Main_Vocals_No_Reverb.{format}")
|
348 |
-
start_title = translations["process_main"]
|
349 |
-
end_title = translations["process_main_success"]
|
350 |
-
elif "Backing_Vocals" in path:
|
351 |
-
reverb_path = os.path.join(output, f"Backing_Vocals_Reverb.{format}")
|
352 |
-
no_reverb_path = os.path.join(output, f"Backing_Vocals_No_Reverb.{format}")
|
353 |
-
start_title = translations["process_backing"]
|
354 |
-
end_title = translations["process_backing_success"]
|
355 |
-
|
356 |
-
logger.info(start_title)
|
357 |
-
output_dereveb = separator_main(audio_file=path, model_filename="Reverb_HQ_By_FoxJoy.onnx", output_format=format, output_dir=output, mdx_segment_size=segments_size, mdx_overlap=overlap, mdx_batch_size=batch_size, mdx_hop_length=hop_length, mdx_enable_denoise=denoise)
|
358 |
-
|
359 |
-
for f in output_dereveb:
|
360 |
-
path = os.path.join(output, f)
|
361 |
-
|
362 |
-
if not os.path.exists(path): logger.error(translations["not_found"].format(name=path))
|
363 |
-
|
364 |
-
if '_(Reverb)_' in f: os.rename(path, reverb_path)
|
365 |
-
elif '_(No Reverb)_' in f: os.rename(path, no_reverb_path)
|
366 |
-
|
367 |
-
logger.info(end_title)
|
368 |
-
|
369 |
-
return (os.path.join(output, f"Original_Vocals_No_Reverb.{format}") if original else None), (os.path.join(output, f"Main_Vocals_No_Reverb.{format}") if main else None), (os.path.join(output, f"Backing_Vocals_No_Reverb.{format}") if backing_reverb else None)
|
370 |
-
|
371 |
-
def separator_main(audio_file=None, model_filename="UVR-MDX-NET_Main_340.onnx", output_format="wav", output_dir=".", mdx_segment_size=256, mdx_overlap=0.25, mdx_batch_size=1, mdx_hop_length=1024, mdx_enable_denoise=True, demucs_segment_size=256, demucs_shifts=2, demucs_overlap=0.25):
|
372 |
-
separator = Separator(
|
373 |
-
log_formatter=file_formatter,
|
374 |
-
log_level=logging.INFO,
|
375 |
-
output_dir=output_dir,
|
376 |
-
output_format=output_format,
|
377 |
-
output_bitrate=None,
|
378 |
-
normalization_threshold=0.9,
|
379 |
-
output_single_stem=None,
|
380 |
-
invert_using_spec=False,
|
381 |
-
sample_rate=44100,
|
382 |
-
mdx_params={
|
383 |
-
"hop_length": mdx_hop_length,
|
384 |
-
"segment_size": mdx_segment_size,
|
385 |
-
"overlap": mdx_overlap,
|
386 |
-
"batch_size": mdx_batch_size,
|
387 |
-
"enable_denoise": mdx_enable_denoise,
|
388 |
-
},
|
389 |
-
demucs_params={
|
390 |
-
"segment_size": demucs_segment_size,
|
391 |
-
"shifts": demucs_shifts,
|
392 |
-
"overlap": demucs_overlap,
|
393 |
-
"segments_enabled": True,
|
394 |
-
}
|
395 |
-
)
|
396 |
-
|
397 |
-
separator.load_model(model_filename=model_filename)
|
398 |
-
return separator.separate(audio_file)
|
399 |
-
|
400 |
-
if __name__ == "__main__": main()
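
Before deletion, the script above was normally run end to end through main(), but its separator_main helper is self-contained. A minimal sketch of calling it directly (the input file song.wav and the output directory ./audios are hypothetical; Kim_Vocal_2.onnx is one of the mdx_models values listed above):

stems = separator_main(
    audio_file="song.wav",              # hypothetical input file
    model_filename="Kim_Vocal_2.onnx",  # any value from the model tables above
    output_format="wav",
    output_dir="./audios",              # hypothetical output directory
    mdx_segment_size=256,
    mdx_overlap=0.25,
    mdx_batch_size=1,
    mdx_hop_length=1024,
    mdx_enable_denoise=True,
)
# As in the loops above, each returned name is a stem file relative to output_dir.
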
main/inference/train.py
DELETED
@@ -1,1600 +0,0 @@
|
|
1 |
-
import os
|
2 |
-
import re
|
3 |
-
import sys
|
4 |
-
import glob
|
5 |
-
import json
|
6 |
-
import torch
|
7 |
-
import logging
|
8 |
-
import hashlib
|
9 |
-
import argparse
|
10 |
-
import datetime
|
11 |
-
import warnings
|
12 |
-
import logging.handlers
|
13 |
-
|
14 |
-
import numpy as np
|
15 |
-
import torch.utils.data
|
16 |
-
import matplotlib.pyplot as plt
|
17 |
-
import torch.distributed as dist
|
18 |
-
import torch.multiprocessing as mp
|
19 |
-
|
20 |
-
from tqdm import tqdm
|
21 |
-
from time import time as ttime
|
22 |
-
from scipy.io.wavfile import read
|
23 |
-
from collections import OrderedDict
|
24 |
-
from random import randint, shuffle
|
25 |
-
|
26 |
-
from torch.nn import functional as F
|
27 |
-
from distutils.util import strtobool
|
28 |
-
from torch.utils.data import DataLoader
|
29 |
-
from torch.cuda.amp import GradScaler, autocast
|
30 |
-
from torch.utils.tensorboard import SummaryWriter
|
31 |
-
from librosa.filters import mel as librosa_mel_fn
|
32 |
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
33 |
-
from torch.nn.utils.parametrizations import spectral_norm, weight_norm
|
34 |
-
|
35 |
-
current_dir = os.getcwd()
|
36 |
-
sys.path.append(current_dir)
|
37 |
-
|
38 |
-
from main.configs.config import Config
|
39 |
-
from main.library.algorithm.residuals import LRELU_SLOPE
|
40 |
-
from main.library.algorithm.synthesizers import Synthesizer
|
41 |
-
from main.library.algorithm.commons import get_padding, slice_segments, clip_grad_value
|
42 |
-
|
43 |
-
warnings.filterwarnings("ignore", category=FutureWarning)
|
44 |
-
warnings.filterwarnings("ignore", category=UserWarning)
|
45 |
-
|
46 |
-
logging.getLogger("torch").setLevel(logging.ERROR)
|
47 |
-
|
48 |
-
MATPLOTLIB_FLAG = False
|
49 |
-
|
50 |
-
translations = Config().translations
|
51 |
-
|
52 |
-
class HParams:
|
53 |
-
def __init__(self, **kwargs):
|
54 |
-
for k, v in kwargs.items():
|
55 |
-
self[k] = HParams(**v) if isinstance(v, dict) else v
|
56 |
-
|
57 |
-
|
58 |
-
def keys(self):
|
59 |
-
return self.__dict__.keys()
|
60 |
-
|
61 |
-
|
62 |
-
def items(self):
|
63 |
-
return self.__dict__.items()
|
64 |
-
|
65 |
-
|
66 |
-
def values(self):
|
67 |
-
return self.__dict__.values()
|
68 |
-
|
69 |
-
|
70 |
-
def __len__(self):
|
71 |
-
return len(self.__dict__)
|
72 |
-
|
73 |
-
|
74 |
-
def __getitem__(self, key):
|
75 |
-
return self.__dict__[key]
|
76 |
-
|
77 |
-
|
78 |
-
def __setitem__(self, key, value):
|
79 |
-
self.__dict__[key] = value
|
80 |
-
|
81 |
-
|
82 |
-
def __contains__(self, key):
|
83 |
-
return key in self.__dict__
|
84 |
-
|
85 |
-
|
86 |
-
def __repr__(self):
|
87 |
-
return repr(self.__dict__)
|
88 |
-
|
89 |
-
def parse_arguments() -> tuple:
|
90 |
-
parser = argparse.ArgumentParser()
|
91 |
-
parser.add_argument("--model_name", type=str, required=True)
|
92 |
-
parser.add_argument("--rvc_version", type=str, default="v2")
|
93 |
-
parser.add_argument("--save_every_epoch", type=int, required=True)
|
94 |
-
parser.add_argument("--save_only_latest", type=lambda x: bool(strtobool(x)), default=True)
|
95 |
-
parser.add_argument("--save_every_weights", type=lambda x: bool(strtobool(x)), default=True)
|
96 |
-
parser.add_argument("--total_epoch", type=int, default=300)
|
97 |
-
parser.add_argument("--sample_rate", type=int, required=True)
|
98 |
-
parser.add_argument("--batch_size", type=int, default=8)
|
99 |
-
parser.add_argument("--gpu", type=str, default="0")
|
100 |
-
parser.add_argument("--pitch_guidance", type=lambda x: bool(strtobool(x)), default=True)
|
101 |
-
parser.add_argument("--g_pretrained_path", type=str, default="")
|
102 |
-
parser.add_argument("--d_pretrained_path", type=str, default="")
|
103 |
-
parser.add_argument("--overtraining_detector", type=lambda x: bool(strtobool(x)), default=False)
|
104 |
-
parser.add_argument("--overtraining_threshold", type=int, default=50)
|
105 |
-
parser.add_argument("--sync_graph", type=lambda x: bool(strtobool(x)), default=False)
|
106 |
-
parser.add_argument("--cache_data_in_gpu", type=lambda x: bool(strtobool(x)), default=False)
|
107 |
-
parser.add_argument("--model_author", type=str)
|
108 |
-
|
109 |
-
args = parser.parse_args()
|
110 |
-
return args
|
111 |
-
|
112 |
-
|
113 |
-
args = parse_arguments()
|
114 |
-
|
115 |
-
model_name = args.model_name
|
116 |
-
save_every_epoch = args.save_every_epoch
|
117 |
-
total_epoch = args.total_epoch
|
118 |
-
pretrainG = args.g_pretrained_path
|
119 |
-
pretrainD = args.d_pretrained_path
|
120 |
-
version = args.rvc_version
|
121 |
-
gpus = args.gpu
|
122 |
-
batch_size = args.batch_size
|
123 |
-
sample_rate = args.sample_rate
|
124 |
-
pitch_guidance = args.pitch_guidance
|
125 |
-
save_only_latest = args.save_only_latest
|
126 |
-
save_every_weights = args.save_every_weights
|
127 |
-
cache_data_in_gpu = args.cache_data_in_gpu
|
128 |
-
overtraining_detector = args.overtraining_detector
|
129 |
-
overtraining_threshold = args.overtraining_threshold
|
130 |
-
sync_graph = args.sync_graph
|
131 |
-
model_author = args.model_author
|
132 |
-
|
133 |
-
experiment_dir = os.path.join(current_dir, "assets", "logs", model_name)
|
134 |
-
config_save_path = os.path.join(experiment_dir, "config.json")
|
135 |
-
|
136 |
-
os.environ["CUDA_VISIBLE_DEVICES"] = gpus.replace("-", ",")
|
137 |
-
n_gpus = len(gpus.split("-"))
|
138 |
-
|
139 |
-
torch.backends.cudnn.deterministic = False
|
140 |
-
torch.backends.cudnn.benchmark = False
|
141 |
-
|
142 |
-
global_step = 0
|
143 |
-
last_loss_gen_all = 0
|
144 |
-
overtrain_save_epoch = 0
|
145 |
-
|
146 |
-
loss_gen_history = []
|
147 |
-
smoothed_loss_gen_history = []
|
148 |
-
loss_disc_history = []
|
149 |
-
smoothed_loss_disc_history = []
|
150 |
-
|
151 |
-
lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
|
152 |
-
training_file_path = os.path.join(experiment_dir, "training_data.json")
|
153 |
-
|
154 |
-
with open(config_save_path, "r") as f:
|
155 |
-
config = json.load(f)
|
156 |
-
|
157 |
-
config = HParams(**config)
|
158 |
-
config.data.training_files = os.path.join(experiment_dir, "filelist.txt")
|
159 |
-
|
160 |
-
log_file = os.path.join(experiment_dir, "train.log")
|
161 |
-
|
162 |
-
logger = logging.getLogger(__name__)
|
163 |
-
|
164 |
-
if logger.hasHandlers(): logger.handlers.clear()
|
165 |
-
else:
|
166 |
-
console_handler = logging.StreamHandler()
|
167 |
-
console_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
168 |
-
|
169 |
-
console_handler.setFormatter(console_formatter)
|
170 |
-
console_handler.setLevel(logging.INFO)
|
171 |
-
|
172 |
-
file_handler = logging.handlers.RotatingFileHandler(log_file, maxBytes=5*1024*1024, backupCount=3, encoding='utf-8')
|
173 |
-
file_formatter = logging.Formatter(fmt="\n%(asctime)s.%(msecs)03d | %(levelname)s | %(module)s | %(message)s", datefmt="%Y-%m-%d %H:%M:%S")
|
174 |
-
|
175 |
-
file_handler.setFormatter(file_formatter)
|
176 |
-
file_handler.setLevel(logging.DEBUG)
|
177 |
-
|
178 |
-
logger.addHandler(console_handler)
|
179 |
-
logger.addHandler(file_handler)
|
180 |
-
logger.setLevel(logging.DEBUG)
|
181 |
-
|
182 |
-
logger.debug(f"{translations['modelname']}: {model_name}")
|
183 |
-
logger.debug(translations["save_every_epoch"].format(save_every_epoch=save_every_epoch))
|
184 |
-
logger.debug(translations["total_e"].format(total_epoch=total_epoch))
|
185 |
-
logger.debug(translations["dorg"].format(pretrainG=pretrainG, pretrainD=pretrainD))
|
186 |
-
logger.debug(f"{translations['training_version']}: {version}")
|
187 |
-
logger.debug(f"Gpu: {gpus}")
|
188 |
-
logger.debug(f"{translations['batch_size']}: {batch_size}")
|
189 |
-
logger.debug(f"{translations['pretrain_sr']}: {sample_rate}")
|
190 |
-
logger.debug(f"{translations['training_f0']}: {pitch_guidance}")
|
191 |
-
logger.debug(f"{translations['save_only_latest']}: {save_only_latest}")
|
192 |
-
logger.debug(f"{translations['save_every_weights']}: {save_every_weights}")
|
193 |
-
logger.debug(f"{translations['cache_in_gpu']}: {cache_data_in_gpu}")
|
194 |
-
logger.debug(f"{translations['overtraining_detector']}: {overtraining_detector}")
|
195 |
-
logger.debug(f"{translations['threshold']}: {overtraining_threshold}")
|
196 |
-
logger.debug(f"{translations['sync_graph']}: {sync_graph}")
|
197 |
-
if not model_author: logger.debug(translations["model_author"].format(model_author=model_author))
|
198 |
-
|
199 |
-
def main():
|
200 |
-
global training_file_path, last_loss_gen_all, smoothed_loss_gen_history, loss_gen_history, loss_disc_history, smoothed_loss_disc_history, overtrain_save_epoch, model_author
|
201 |
-
os.environ["MASTER_ADDR"] = "localhost"
|
202 |
-
os.environ["MASTER_PORT"] = str(randint(20000, 55555))
|
203 |
-
|
204 |
-
if torch.cuda.is_available():
|
205 |
-
device = torch.device("cuda")
|
206 |
-
n_gpus = torch.cuda.device_count()
|
207 |
-
elif torch.backends.mps.is_available():
|
208 |
-
device = torch.device("mps")
|
209 |
-
n_gpus = 1
|
210 |
-
else:
|
211 |
-
device = torch.device("cpu")
|
212 |
-
n_gpus = 1
|
213 |
-
|
214 |
-
def start():
|
215 |
-
children = []
|
216 |
-
|
217 |
-
for i in range(n_gpus):
|
218 |
-
subproc = mp.Process(target=run, args=(i, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, custom_total_epoch, custom_save_every_weights, config, device, model_author))
|
219 |
-
children.append(subproc)
|
220 |
-
subproc.start()
|
221 |
-
|
222 |
-
for i in range(n_gpus):
|
223 |
-
children[i].join()
|
224 |
-
|
225 |
-
def load_from_json(file_path):
|
226 |
-
if os.path.exists(file_path):
|
227 |
-
with open(file_path, "r") as f:
|
228 |
-
data = json.load(f)
|
229 |
-
|
230 |
-
return (
|
231 |
-
data.get("loss_disc_history", []),
|
232 |
-
data.get("smoothed_loss_disc_history", []),
|
233 |
-
data.get("loss_gen_history", []),
|
234 |
-
data.get("smoothed_loss_gen_history", []),
|
235 |
-
)
|
236 |
-
|
237 |
-
return [], [], [], []
|
238 |
-
|
239 |
-
def continue_overtrain_detector(training_file_path):
|
240 |
-
if overtraining_detector:
|
241 |
-
if os.path.exists(training_file_path):
|
242 |
-
(
|
243 |
-
loss_disc_history,
|
244 |
-
smoothed_loss_disc_history,
|
245 |
-
loss_gen_history,
|
246 |
-
smoothed_loss_gen_history,
|
247 |
-
) = load_from_json(training_file_path)
|
248 |
-
|
249 |
-
|
250 |
-
n_gpus = torch.cuda.device_count()
|
251 |
-
|
252 |
-
if not torch.cuda.is_available() and torch.backends.mps.is_available(): n_gpus = 1
|
253 |
-
|
254 |
-
if n_gpus < 1:
|
255 |
-
logger.warning(translations["not_gpu"])
|
256 |
-
n_gpus = 1
|
257 |
-
|
258 |
-
if sync_graph:
|
259 |
-
logger.debug(translations["sync"])
|
260 |
-
custom_total_epoch = 1
|
261 |
-
custom_save_every_weights = True
|
262 |
-
|
263 |
-
start()
|
264 |
-
|
265 |
-
model_config_file = os.path.join(experiment_dir, "config.json")
|
266 |
-
rvc_config_file = os.path.join(current_dir, "main", "configs", version, str(sample_rate) + ".json")
|
267 |
-
|
268 |
-
if not os.path.exists(rvc_config_file): rvc_config_file = os.path.join(current_dir, "main", "configs", "v1", str(sample_rate) + ".json")
|
269 |
-
|
270 |
-
pattern = rf"{os.path.basename(model_name)}_(\d+)e_(\d+)s\.pth"
|
271 |
-
|
272 |
-
for filename in os.listdir(os.path.join("assets", "weights")):
|
273 |
-
match = re.match(pattern, filename)
|
274 |
-
|
275 |
-
if match: steps = int(match.group(2))
|
276 |
-
|
277 |
-
def edit_config(config_file):
|
278 |
-
with open(config_file, "r", encoding="utf8") as json_file:
|
279 |
-
config_data = json.load(json_file)
|
280 |
-
|
281 |
-
config_data["train"]["log_interval"] = steps
|
282 |
-
|
283 |
-
with open(config_file, "w", encoding="utf8") as json_file:
|
284 |
-
json.dump(config_data, json_file, indent=2, separators=(",", ": "), ensure_ascii=False)
|
285 |
-
|
286 |
-
edit_config(model_config_file)
|
287 |
-
edit_config(rvc_config_file)
|
288 |
-
|
289 |
-
for root, dirs, files in os.walk(experiment_dir, topdown=False):
|
290 |
-
for name in files:
|
291 |
-
file_path = os.path.join(root, name)
|
292 |
-
_, file_extension = os.path.splitext(name)
|
293 |
-
|
294 |
-
if file_extension == ".0": os.remove(file_path)
|
295 |
-
elif ("D" in name or "G" in name) and file_extension == ".pth": os.remove(file_path)
|
296 |
-
elif ("added" in name or "trained" in name) and file_extension == ".index": os.remove(file_path)
|
297 |
-
|
298 |
-
for name in dirs:
|
299 |
-
if name == "eval":
|
300 |
-
folder_path = os.path.join(root, name)
|
301 |
-
|
302 |
-
for item in os.listdir(folder_path):
|
303 |
-
item_path = os.path.join(folder_path, item)
|
304 |
-
if os.path.isfile(item_path): os.remove(item_path)
|
305 |
-
|
306 |
-
os.rmdir(folder_path)
|
307 |
-
|
308 |
-
logger.info(translations["sync_success"])
|
309 |
-
custom_total_epoch = total_epoch
|
310 |
-
custom_save_every_weights = save_every_weights
|
311 |
-
|
312 |
-
continue_overtrain_detector(training_file_path)
|
313 |
-
start()
|
314 |
-
else:
|
315 |
-
custom_total_epoch = total_epoch
|
316 |
-
custom_save_every_weights = save_every_weights
|
317 |
-
|
318 |
-
continue_overtrain_detector(training_file_path)
|
319 |
-
start()
|
320 |
-
|
321 |
-
|
322 |
-
def plot_spectrogram_to_numpy(spectrogram):
|
323 |
-
global MATPLOTLIB_FLAG
|
324 |
-
|
325 |
-
if not MATPLOTLIB_FLAG:
|
326 |
-
plt.switch_backend("Agg")
|
327 |
-
MATPLOTLIB_FLAG = True
|
328 |
-
|
329 |
-
fig, ax = plt.subplots(figsize=(10, 2))
|
330 |
-
im = ax.imshow(spectrogram, aspect="auto", origin="lower", interpolation="none")
|
331 |
-
|
332 |
-
plt.colorbar(im, ax=ax)
|
333 |
-
plt.xlabel("Frames")
|
334 |
-
plt.ylabel("Channels")
|
335 |
-
plt.tight_layout()
|
336 |
-
|
337 |
-
fig.canvas.draw()
|
338 |
-
data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
|
339 |
-
data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
|
340 |
-
plt.close(fig)
|
341 |
-
return data
|
342 |
-
|
343 |
-
|
344 |
-
def summarize(writer, global_step, scalars={}, histograms={}, images={}, audios={}, audio_sample_rate=22050):
|
345 |
-
for k, v in scalars.items():
|
346 |
-
writer.add_scalar(k, v, global_step)
|
347 |
-
|
348 |
-
for k, v in histograms.items():
|
349 |
-
writer.add_histogram(k, v, global_step)
|
350 |
-
|
351 |
-
for k, v in images.items():
|
352 |
-
writer.add_image(k, v, global_step, dataformats="HWC")
|
353 |
-
|
354 |
-
for k, v in audios.items():
|
355 |
-
writer.add_audio(k, v, global_step, audio_sample_rate)
|
356 |
-
|
357 |
-
|
358 |
-
def load_checkpoint(checkpoint_path, model, optimizer=None, load_opt=1):
|
359 |
-
assert os.path.isfile(checkpoint_path), translations["not_found_checkpoint"].format(checkpoint_path=checkpoint_path)
|
360 |
-
|
361 |
-
checkpoint_dict = torch.load(checkpoint_path, map_location="cpu")
|
362 |
-
checkpoint_dict = replace_keys_in_dict(replace_keys_in_dict(checkpoint_dict, ".weight_v", ".parametrizations.weight.original1"), ".weight_g", ".parametrizations.weight.original0")
|
363 |
-
|
364 |
-
model_state_dict = (model.module.state_dict() if hasattr(model, "module") else model.state_dict())
|
365 |
-
new_state_dict = {k: checkpoint_dict["model"].get(k, v) for k, v in model_state_dict.items()}
|
366 |
-
|
367 |
-
if hasattr(model, "module"): model.module.load_state_dict(new_state_dict, strict=False)
|
368 |
-
else: model.load_state_dict(new_state_dict, strict=False)
|
369 |
-
|
370 |
-
if optimizer and load_opt == 1: optimizer.load_state_dict(checkpoint_dict.get("optimizer", {}))
|
371 |
-
|
372 |
-
logger.debug(translations["save_checkpoint"].format(checkpoint_path=checkpoint_path, checkpoint_dict=checkpoint_dict['iteration']))
|
373 |
-
|
374 |
-
return (
|
375 |
-
model,
|
376 |
-
optimizer,
|
377 |
-
checkpoint_dict.get("learning_rate", 0),
|
378 |
-
checkpoint_dict["iteration"],
|
379 |
-
)
|
380 |
-
|
381 |
-
|
382 |
-
def save_checkpoint(model, optimizer, learning_rate, iteration, checkpoint_path):
|
383 |
-
state_dict = (model.module.state_dict() if hasattr(model, "module") else model.state_dict())
|
384 |
-
|
385 |
-
checkpoint_data = {
|
386 |
-
"model": state_dict,
|
387 |
-
"iteration": iteration,
|
388 |
-
"optimizer": optimizer.state_dict(),
|
389 |
-
"learning_rate": learning_rate,
|
390 |
-
}
|
391 |
-
|
392 |
-
torch.save(checkpoint_data, checkpoint_path)
|
393 |
-
|
394 |
-
old_version_path = checkpoint_path.replace(".pth", "_old_version.pth")
|
395 |
-
checkpoint_data = replace_keys_in_dict(replace_keys_in_dict(checkpoint_data, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g")
|
396 |
-
|
397 |
-
torch.save(checkpoint_data, old_version_path)
|
398 |
-
|
399 |
-
os.replace(old_version_path, checkpoint_path)
|
400 |
-
logger.info(translations["save_model"].format(checkpoint_path=checkpoint_path, iteration=iteration))
|
401 |
-
|
402 |
-
|
403 |
-
def latest_checkpoint_path(dir_path, regex="G_*.pth"):
|
404 |
-
checkpoints = sorted(glob.glob(os.path.join(dir_path, regex)), key=lambda f: int("".join(filter(str.isdigit, f))))
|
405 |
-
|
406 |
-
return checkpoints[-1] if checkpoints else None
|
407 |
-
|
408 |
-
|
409 |
-
def load_wav_to_torch(full_path):
|
410 |
-
sample_rate, data = read(full_path)
|
411 |
-
|
412 |
-
return torch.FloatTensor(data.astype(np.float32)), sample_rate
|
413 |
-
|
414 |
-
|
415 |
-
def load_filepaths_and_text(filename, split="|"):
|
416 |
-
with open(filename, encoding="utf-8") as f:
|
417 |
-
return [line.strip().split(split) for line in f]
|
418 |
-
|
419 |
-
|
420 |
-
def feature_loss(fmap_r, fmap_g):
|
421 |
-
loss = 0
|
422 |
-
|
423 |
-
for dr, dg in zip(fmap_r, fmap_g):
|
424 |
-
for rl, gl in zip(dr, dg):
|
425 |
-
rl = rl.float().detach()
|
426 |
-
gl = gl.float()
|
427 |
-
loss += torch.mean(torch.abs(rl - gl))
|
428 |
-
|
429 |
-
return loss * 2
|
430 |
-
|
431 |
-
|
432 |
-
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
|
433 |
-
loss = 0
|
434 |
-
r_losses = []
|
435 |
-
g_losses = []
|
436 |
-
|
437 |
-
for dr, dg in zip(disc_real_outputs, disc_generated_outputs):
|
438 |
-
dr = dr.float()
|
439 |
-
dg = dg.float()
|
440 |
-
r_loss = torch.mean((1 - dr) ** 2)
|
441 |
-
g_loss = torch.mean(dg**2)
|
442 |
-
loss += r_loss + g_loss
|
443 |
-
r_losses.append(r_loss.item())
|
444 |
-
g_losses.append(g_loss.item())
|
445 |
-
|
446 |
-
return loss, r_losses, g_losses
|
447 |
-
|
448 |
-
|
449 |
-
def generator_loss(disc_outputs):
|
450 |
-
loss = 0
|
451 |
-
gen_losses = []
|
452 |
-
|
453 |
-
for dg in disc_outputs:
|
454 |
-
dg = dg.float()
|
455 |
-
l = torch.mean((1 - dg) ** 2)
|
456 |
-
gen_losses.append(l)
|
457 |
-
loss += l
|
458 |
-
|
459 |
-
return loss, gen_losses
|
460 |
-
|
461 |
-
|
462 |
-
def kl_loss(z_p, logs_q, m_p, logs_p, z_mask):
|
463 |
-
z_p = z_p.float()
|
464 |
-
logs_q = logs_q.float()
|
465 |
-
m_p = m_p.float()
|
466 |
-
logs_p = logs_p.float()
|
467 |
-
z_mask = z_mask.float()
|
468 |
-
|
469 |
-
kl = logs_p - logs_q - 0.5
|
470 |
-
kl += 0.5 * ((z_p - m_p) ** 2) * torch.exp(-2.0 * logs_p)
|
471 |
-
kl = torch.sum(kl * z_mask)
|
472 |
-
l = kl / torch.sum(z_mask)
|
473 |
-
|
474 |
-
return l
|
475 |
-
|
476 |
-
|
477 |
-
class TextAudioLoaderMultiNSFsid(torch.utils.data.Dataset):
|
478 |
-
def __init__(self, hparams):
|
479 |
-
self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
|
480 |
-
self.max_wav_value = hparams.max_wav_value
|
481 |
-
self.sample_rate = hparams.sample_rate
|
482 |
-
self.filter_length = hparams.filter_length
|
483 |
-
self.hop_length = hparams.hop_length
|
484 |
-
self.win_length = hparams.win_length
|
485 |
-
self.sample_rate = hparams.sample_rate
|
486 |
-
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
487 |
-
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
488 |
-
self._filter()
|
489 |
-
|
490 |
-
def _filter(self):
|
491 |
-
audiopaths_and_text_new = []
|
492 |
-
lengths = []
|
493 |
-
|
494 |
-
for audiopath, text, pitch, pitchf, dv in self.audiopaths_and_text:
|
495 |
-
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
|
496 |
-
audiopaths_and_text_new.append([audiopath, text, pitch, pitchf, dv])
|
497 |
-
lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
|
498 |
-
|
499 |
-
self.audiopaths_and_text = audiopaths_and_text_new
|
500 |
-
self.lengths = lengths
|
501 |
-
|
502 |
-
def get_sid(self, sid):
|
503 |
-
try:
|
504 |
-
sid = torch.LongTensor([int(sid)])
|
505 |
-
except ValueError as e:
|
506 |
-
logger.error(translations["sid_error"].format(sid=sid, e=e))
|
507 |
-
sid = torch.LongTensor([0])
|
508 |
-
|
509 |
-
return sid
|
510 |
-
|
511 |
-
def get_audio_text_pair(self, audiopath_and_text):
|
512 |
-
file = audiopath_and_text[0]
|
513 |
-
phone = audiopath_and_text[1]
|
514 |
-
pitch = audiopath_and_text[2]
|
515 |
-
pitchf = audiopath_and_text[3]
|
516 |
-
dv = audiopath_and_text[4]
|
517 |
-
|
518 |
-
phone, pitch, pitchf = self.get_labels(phone, pitch, pitchf)
|
519 |
-
spec, wav = self.get_audio(file)
|
520 |
-
dv = self.get_sid(dv)
|
521 |
-
|
522 |
-
len_phone = phone.size()[0]
|
523 |
-
len_spec = spec.size()[-1]
|
524 |
-
if len_phone != len_spec:
|
525 |
-
len_min = min(len_phone, len_spec)
|
526 |
-
len_wav = len_min * self.hop_length
|
527 |
-
|
528 |
-
spec = spec[:, :len_min]
|
529 |
-
wav = wav[:, :len_wav]
|
530 |
-
|
531 |
-
phone = phone[:len_min, :]
|
532 |
-
pitch = pitch[:len_min]
|
533 |
-
pitchf = pitchf[:len_min]
|
534 |
-
|
535 |
-
return (spec, wav, phone, pitch, pitchf, dv)
|
536 |
-
|
537 |
-
def get_labels(self, phone, pitch, pitchf):
|
538 |
-
phone = np.load(phone)
|
539 |
-
phone = np.repeat(phone, 2, axis=0)
|
540 |
-
|
541 |
-
pitch = np.load(pitch)
|
542 |
-
pitchf = np.load(pitchf)
|
543 |
-
|
544 |
-
n_num = min(phone.shape[0], 900)
|
545 |
-
phone = phone[:n_num, :]
|
546 |
-
|
547 |
-
pitch = pitch[:n_num]
|
548 |
-
pitchf = pitchf[:n_num]
|
549 |
-
|
550 |
-
phone = torch.FloatTensor(phone)
|
551 |
-
|
552 |
-
pitch = torch.LongTensor(pitch)
|
553 |
-
pitchf = torch.FloatTensor(pitchf)
|
554 |
-
|
555 |
-
return phone, pitch, pitchf
|
556 |
-
|
557 |
-
def get_audio(self, filename):
|
558 |
-
audio, sample_rate = load_wav_to_torch(filename)
|
559 |
-
|
560 |
-
if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
|
561 |
-
|
562 |
-
audio_norm = audio
|
563 |
-
audio_norm = audio_norm.unsqueeze(0)
|
564 |
-
spec_filename = filename.replace(".wav", ".spec.pt")
|
565 |
-
|
566 |
-
if os.path.exists(spec_filename):
|
567 |
-
try:
|
568 |
-
spec = torch.load(spec_filename)
|
569 |
-
except Exception as e:
|
570 |
-
logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
|
571 |
-
spec = spectrogram_torch(
|
572 |
-
audio_norm,
|
573 |
-
self.filter_length,
|
574 |
-
self.hop_length,
|
575 |
-
self.win_length,
|
576 |
-
center=False,
|
577 |
-
)
|
578 |
-
spec = torch.squeeze(spec, 0)
|
579 |
-
|
580 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
581 |
-
else:
|
582 |
-
spec = spectrogram_torch(
|
583 |
-
audio_norm,
|
584 |
-
self.filter_length,
|
585 |
-
self.hop_length,
|
586 |
-
self.win_length,
|
587 |
-
center=False,
|
588 |
-
)
|
589 |
-
spec = torch.squeeze(spec, 0)
|
590 |
-
|
591 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
592 |
-
return spec, audio_norm
|
593 |
-
|
594 |
-
def __getitem__(self, index):
|
595 |
-
return self.get_audio_text_pair(self.audiopaths_and_text[index])
|
596 |
-
|
597 |
-
def __len__(self):
|
598 |
-
return len(self.audiopaths_and_text)
|
599 |
-
|
600 |
-
|
601 |
-
class TextAudioCollateMultiNSFsid:
|
602 |
-
def __init__(self, return_ids=False):
|
603 |
-
self.return_ids = return_ids
|
604 |
-
|
605 |
-
def __call__(self, batch):
|
606 |
-
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
|
607 |
-
|
608 |
-
max_spec_len = max([x[0].size(1) for x in batch])
|
609 |
-
max_wave_len = max([x[1].size(1) for x in batch])
|
610 |
-
|
611 |
-
spec_lengths = torch.LongTensor(len(batch))
|
612 |
-
wave_lengths = torch.LongTensor(len(batch))
|
613 |
-
spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
|
614 |
-
wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
|
615 |
-
|
616 |
-
spec_padded.zero_()
|
617 |
-
wave_padded.zero_()
|
618 |
-
|
619 |
-
max_phone_len = max([x[2].size(0) for x in batch])
|
620 |
-
|
621 |
-
phone_lengths = torch.LongTensor(len(batch))
|
622 |
-
phone_padded = torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
|
623 |
-
pitch_padded = torch.LongTensor(len(batch), max_phone_len)
|
624 |
-
pitchf_padded = torch.FloatTensor(len(batch), max_phone_len)
|
625 |
-
|
626 |
-
phone_padded.zero_()
|
627 |
-
pitch_padded.zero_()
|
628 |
-
pitchf_padded.zero_()
|
629 |
-
sid = torch.LongTensor(len(batch))
|
630 |
-
|
631 |
-
for i in range(len(ids_sorted_decreasing)):
|
632 |
-
row = batch[ids_sorted_decreasing[i]]
|
633 |
-
|
634 |
-
spec = row[0]
|
635 |
-
spec_padded[i, :, : spec.size(1)] = spec
|
636 |
-
spec_lengths[i] = spec.size(1)
|
637 |
-
|
638 |
-
wave = row[1]
|
639 |
-
wave_padded[i, :, : wave.size(1)] = wave
|
640 |
-
wave_lengths[i] = wave.size(1)
|
641 |
-
|
642 |
-
phone = row[2]
|
643 |
-
phone_padded[i, : phone.size(0), :] = phone
|
644 |
-
phone_lengths[i] = phone.size(0)
|
645 |
-
|
646 |
-
pitch = row[3]
|
647 |
-
pitch_padded[i, : pitch.size(0)] = pitch
|
648 |
-
pitchf = row[4]
|
649 |
-
pitchf_padded[i, : pitchf.size(0)] = pitchf
|
650 |
-
|
651 |
-
sid[i] = row[5]
|
652 |
-
|
653 |
-
return (
|
654 |
-
phone_padded,
|
655 |
-
phone_lengths,
|
656 |
-
pitch_padded,
|
657 |
-
pitchf_padded,
|
658 |
-
spec_padded,
|
659 |
-
spec_lengths,
|
660 |
-
wave_padded,
|
661 |
-
wave_lengths,
|
662 |
-
sid,
|
663 |
-
)
|
664 |
-
|
665 |
-
|
666 |
-
class TextAudioLoader(torch.utils.data.Dataset):
|
667 |
-
def __init__(self, hparams):
|
668 |
-
self.audiopaths_and_text = load_filepaths_and_text(hparams.training_files)
|
669 |
-
self.max_wav_value = hparams.max_wav_value
|
670 |
-
self.sample_rate = hparams.sample_rate
|
671 |
-
self.filter_length = hparams.filter_length
|
672 |
-
self.hop_length = hparams.hop_length
|
673 |
-
self.win_length = hparams.win_length
|
674 |
-
self.sample_rate = hparams.sample_rate
|
675 |
-
self.min_text_len = getattr(hparams, "min_text_len", 1)
|
676 |
-
self.max_text_len = getattr(hparams, "max_text_len", 5000)
|
677 |
-
self._filter()
|
678 |
-
|
679 |
-
def _filter(self):
|
680 |
-
audiopaths_and_text_new = []
|
681 |
-
lengths = []
|
682 |
-
|
683 |
-
for entry in self.audiopaths_and_text:
|
684 |
-
if len(entry) >= 3:
|
685 |
-
audiopath, text, dv = entry[:3]
|
686 |
-
|
687 |
-
if self.min_text_len <= len(text) and len(text) <= self.max_text_len:
|
688 |
-
audiopaths_and_text_new.append([audiopath, text, dv])
|
689 |
-
lengths.append(os.path.getsize(audiopath) // (3 * self.hop_length))
|
690 |
-
|
691 |
-
self.audiopaths_and_text = audiopaths_and_text_new
|
692 |
-
self.lengths = lengths
|
693 |
-
|
694 |
-
def get_sid(self, sid):
|
695 |
-
try:
|
696 |
-
sid = torch.LongTensor([int(sid)])
|
697 |
-
except ValueError as e:
|
698 |
-
logger.error(translations["sid_error"].format(sid=sid, e=e))
|
699 |
-
sid = torch.LongTensor([0])
|
700 |
-
|
701 |
-
return sid
|
702 |
-
|
703 |
-
def get_audio_text_pair(self, audiopath_and_text):
|
704 |
-
file = audiopath_and_text[0]
|
705 |
-
phone = audiopath_and_text[1]
|
706 |
-
dv = audiopath_and_text[2]
|
707 |
-
|
708 |
-
phone = self.get_labels(phone)
|
709 |
-
spec, wav = self.get_audio(file)
|
710 |
-
dv = self.get_sid(dv)
|
711 |
-
|
712 |
-
len_phone = phone.size()[0]
|
713 |
-
len_spec = spec.size()[-1]
|
714 |
-
|
715 |
-
if len_phone != len_spec:
|
716 |
-
len_min = min(len_phone, len_spec)
|
717 |
-
len_wav = len_min * self.hop_length
|
718 |
-
spec = spec[:, :len_min]
|
719 |
-
wav = wav[:, :len_wav]
|
720 |
-
phone = phone[:len_min, :]
|
721 |
-
|
722 |
-
return (spec, wav, phone, dv)
|
723 |
-
|
724 |
-
def get_labels(self, phone):
|
725 |
-
phone = np.load(phone)
|
726 |
-
phone = np.repeat(phone, 2, axis=0)
|
727 |
-
n_num = min(phone.shape[0], 900)
|
728 |
-
phone = phone[:n_num, :]
|
729 |
-
phone = torch.FloatTensor(phone)
|
730 |
-
return phone
|
731 |
-
|
732 |
-
def get_audio(self, filename):
|
733 |
-
audio, sample_rate = load_wav_to_torch(filename)
|
734 |
-
|
735 |
-
if sample_rate != self.sample_rate: raise ValueError(translations["sr_does_not_match"].format(sample_rate=sample_rate, sample_rate2=self.sample_rate))
|
736 |
-
|
737 |
-
audio_norm = audio
|
738 |
-
audio_norm = audio_norm.unsqueeze(0)
|
739 |
-
|
740 |
-
spec_filename = filename.replace(".wav", ".spec.pt")
|
741 |
-
|
742 |
-
if os.path.exists(spec_filename):
|
743 |
-
try:
|
744 |
-
spec = torch.load(spec_filename)
|
745 |
-
except Exception as e:
|
746 |
-
logger.error(translations["spec_error"].format(spec_filename=spec_filename, e=e))
|
747 |
-
spec = spectrogram_torch(
|
748 |
-
audio_norm,
|
749 |
-
self.filter_length,
|
750 |
-
self.hop_length,
|
751 |
-
self.win_length,
|
752 |
-
center=False,
|
753 |
-
)
|
754 |
-
spec = torch.squeeze(spec, 0)
|
755 |
-
|
756 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
757 |
-
else:
|
758 |
-
spec = spectrogram_torch(
|
759 |
-
audio_norm,
|
760 |
-
self.filter_length,
|
761 |
-
self.hop_length,
|
762 |
-
self.win_length,
|
763 |
-
center=False,
|
764 |
-
)
|
765 |
-
spec = torch.squeeze(spec, 0)
|
766 |
-
|
767 |
-
torch.save(spec, spec_filename, _use_new_zipfile_serialization=False)
|
768 |
-
return spec, audio_norm
|
769 |
-
|
770 |
-
def __getitem__(self, index):
|
771 |
-
return self.get_audio_text_pair(self.audiopaths_and_text[index])
|
772 |
-
|
773 |
-
def __len__(self):
|
774 |
-
return len(self.audiopaths_and_text)
|
775 |
-
|
776 |
-
|
777 |
-
class TextAudioCollate:
|
778 |
-
def __init__(self, return_ids=False):
|
779 |
-
self.return_ids = return_ids
|
780 |
-
|
781 |
-
def __call__(self, batch):
|
782 |
-
_, ids_sorted_decreasing = torch.sort(torch.LongTensor([x[0].size(1) for x in batch]), dim=0, descending=True)
|
783 |
-
|
784 |
-
max_spec_len = max([x[0].size(1) for x in batch])
|
785 |
-
max_wave_len = max([x[1].size(1) for x in batch])
|
786 |
-
|
787 |
-
spec_lengths = torch.LongTensor(len(batch))
|
788 |
-
wave_lengths = torch.LongTensor(len(batch))
|
789 |
-
|
790 |
-
spec_padded = torch.FloatTensor(len(batch), batch[0][0].size(0), max_spec_len)
|
791 |
-
wave_padded = torch.FloatTensor(len(batch), 1, max_wave_len)
|
792 |
-
|
793 |
-
spec_padded.zero_()
|
794 |
-
wave_padded.zero_()
|
795 |
-
|
796 |
-
max_phone_len = max([x[2].size(0) for x in batch])
|
797 |
-
|
798 |
-
phone_lengths = torch.LongTensor(len(batch))
|
799 |
-
phone_padded = torch.FloatTensor(len(batch), max_phone_len, batch[0][2].shape[1])
|
800 |
-
|
801 |
-
phone_padded.zero_()
|
802 |
-
sid = torch.LongTensor(len(batch))
|
803 |
-
|
804 |
-
for i in range(len(ids_sorted_decreasing)):
|
805 |
-
row = batch[ids_sorted_decreasing[i]]
|
806 |
-
|
807 |
-
spec = row[0]
|
808 |
-
spec_padded[i, :, : spec.size(1)] = spec
|
809 |
-
spec_lengths[i] = spec.size(1)
|
810 |
-
|
811 |
-
wave = row[1]
|
812 |
-
wave_padded[i, :, : wave.size(1)] = wave
|
813 |
-
wave_lengths[i] = wave.size(1)
|
814 |
-
|
815 |
-
phone = row[2]
|
816 |
-
phone_padded[i, : phone.size(0), :] = phone
|
817 |
-
phone_lengths[i] = phone.size(0)
|
818 |
-
|
819 |
-
sid[i] = row[3]
|
820 |
-
|
821 |
-
return (
|
822 |
-
phone_padded,
|
823 |
-
phone_lengths,
|
824 |
-
spec_padded,
|
825 |
-
spec_lengths,
|
826 |
-
wave_padded,
|
827 |
-
wave_lengths,
|
828 |
-
sid,
|
829 |
-
)
|
830 |
-
|
831 |
-
|
832 |
-
class DistributedBucketSampler(torch.utils.data.distributed.DistributedSampler):
|
833 |
-
def __init__(self, dataset, batch_size, boundaries, num_replicas=None, rank=None, shuffle=True):
|
834 |
-
super().__init__(dataset, num_replicas=num_replicas, rank=rank, shuffle=shuffle)
|
835 |
-
self.lengths = dataset.lengths
|
836 |
-
self.batch_size = batch_size
|
837 |
-
self.boundaries = boundaries
|
838 |
-
self.buckets, self.num_samples_per_bucket = self._create_buckets()
|
839 |
-
self.total_size = sum(self.num_samples_per_bucket)
|
840 |
-
self.num_samples = self.total_size // self.num_replicas
|
841 |
-
|
842 |
-
def _create_buckets(self):
|
843 |
-
buckets = [[] for _ in range(len(self.boundaries) - 1)]
|
844 |
-
|
845 |
-
for i in range(len(self.lengths)):
|
846 |
-
length = self.lengths[i]
|
847 |
-
idx_bucket = self._bisect(length)
|
848 |
-
if idx_bucket != -1: buckets[idx_bucket].append(i)
|
849 |
-
|
850 |
-
for i in range(len(buckets) - 1, -1, -1):
|
851 |
-
if len(buckets[i]) == 0:
|
852 |
-
buckets.pop(i)
|
853 |
-
self.boundaries.pop(i + 1)
|
854 |
-
|
855 |
-
num_samples_per_bucket = []
|
856 |
-
|
857 |
-
for i in range(len(buckets)):
|
858 |
-
len_bucket = len(buckets[i])
|
859 |
-
total_batch_size = self.num_replicas * self.batch_size
|
860 |
-
rem = (total_batch_size - (len_bucket % total_batch_size)) % total_batch_size
|
861 |
-
num_samples_per_bucket.append(len_bucket + rem)
|
862 |
-
|
863 |
-
return buckets, num_samples_per_bucket
|
864 |
-
|
865 |
-
def __iter__(self):
|
866 |
-
g = torch.Generator()
|
867 |
-
g.manual_seed(self.epoch)
|
868 |
-
|
869 |
-
indices = []
|
870 |
-
|
871 |
-
if self.shuffle:
|
872 |
-
for bucket in self.buckets:
|
873 |
-
indices.append(torch.randperm(len(bucket), generator=g).tolist())
|
874 |
-
else:
|
875 |
-
for bucket in self.buckets:
|
876 |
-
indices.append(list(range(len(bucket))))
|
877 |
-
|
878 |
-
batches = []
|
879 |
-
|
880 |
-
for i in range(len(self.buckets)):
|
881 |
-
bucket = self.buckets[i]
|
882 |
-
len_bucket = len(bucket)
|
883 |
-
ids_bucket = indices[i]
|
884 |
-
num_samples_bucket = self.num_samples_per_bucket[i]
|
885 |
-
|
886 |
-
rem = num_samples_bucket - len_bucket
|
887 |
-
|
888 |
-
ids_bucket = (ids_bucket + ids_bucket * (rem // len_bucket) + ids_bucket[: (rem % len_bucket)])
|
889 |
-
ids_bucket = ids_bucket[self.rank :: self.num_replicas]
|
890 |
-
|
891 |
-
for j in range(len(ids_bucket) // self.batch_size):
|
892 |
-
batch = [bucket[idx] for idx in ids_bucket[j * self.batch_size : (j + 1) * self.batch_size]]
|
893 |
-
batches.append(batch)
|
894 |
-
|
895 |
-
if self.shuffle:
|
896 |
-
batch_ids = torch.randperm(len(batches), generator=g).tolist()
|
897 |
-
batches = [batches[i] for i in batch_ids]
|
898 |
-
|
899 |
-
self.batches = batches
|
900 |
-
|
901 |
-
assert len(self.batches) * self.batch_size == self.num_samples
|
902 |
-
return iter(self.batches)
|
903 |
-
|
904 |
-
def _bisect(self, x, lo=0, hi=None):
|
905 |
-
if hi is None: hi = len(self.boundaries) - 1
|
906 |
-
|
907 |
-
if hi > lo:
|
908 |
-
mid = (hi + lo) // 2
|
909 |
-
|
910 |
-
if self.boundaries[mid] < x and x <= self.boundaries[mid + 1]: return mid
|
911 |
-
elif x <= self.boundaries[mid]: return self._bisect(x, lo, mid)
|
912 |
-
else: return self._bisect(x, mid + 1, hi)
|
913 |
-
|
914 |
-
else: return -1
|
915 |
-
|
916 |
-
def __len__(self):
|
917 |
-
return self.num_samples // self.batch_size
|
918 |
-
|
919 |
-
|
920 |
-
class MultiPeriodDiscriminator(torch.nn.Module):
|
921 |
-
def __init__(self, use_spectral_norm=False):
|
922 |
-
super(MultiPeriodDiscriminator, self).__init__()
|
923 |
-
periods = [2, 3, 5, 7, 11, 17]
|
924 |
-
self.discriminators = torch.nn.ModuleList([DiscriminatorS(use_spectral_norm=use_spectral_norm)] + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods])
|
925 |
-
|
926 |
-
def forward(self, y, y_hat):
|
927 |
-
y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
|
928 |
-
|
929 |
-
for d in self.discriminators:
|
930 |
-
y_d_r, fmap_r = d(y)
|
931 |
-
y_d_g, fmap_g = d(y_hat)
|
932 |
-
y_d_rs.append(y_d_r)
|
933 |
-
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class MultiPeriodDiscriminatorV2(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(MultiPeriodDiscriminatorV2, self).__init__()
        periods = [2, 3, 5, 7, 11, 17, 23, 37]
        self.discriminators = torch.nn.ModuleList([DiscriminatorS(use_spectral_norm=use_spectral_norm)] + [DiscriminatorP(p, use_spectral_norm=use_spectral_norm) for p in periods])

    def forward(self, y, y_hat):
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []

        for d in self.discriminators:
            y_d_r, fmap_r = d(y)
            y_d_g, fmap_g = d(y_hat)
            y_d_rs.append(y_d_r)
            y_d_gs.append(y_d_g)
            fmap_rs.append(fmap_r)
            fmap_gs.append(fmap_g)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs


class DiscriminatorS(torch.nn.Module):
    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        self.convs = torch.nn.ModuleList([norm_f(torch.nn.Conv1d(1, 16, 15, 1, padding=7)), norm_f(torch.nn.Conv1d(16, 64, 41, 4, groups=4, padding=20)), norm_f(torch.nn.Conv1d(64, 256, 41, 4, groups=16, padding=20)), norm_f(torch.nn.Conv1d(256, 1024, 41, 4, groups=64, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 41, 4, groups=256, padding=20)), norm_f(torch.nn.Conv1d(1024, 1024, 5, 1, padding=2))])
        self.conv_post = norm_f(torch.nn.Conv1d(1024, 1, 3, 1, padding=1))
        self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)

    def forward(self, x):
        fmap = []

        for conv in self.convs:
            x = self.lrelu(conv(x))
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap


class DiscriminatorP(torch.nn.Module):
    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = spectral_norm if use_spectral_norm else weight_norm

        in_channels = [1, 32, 128, 512, 1024]
        out_channels = [32, 128, 512, 1024, 1024]

        self.convs = torch.nn.ModuleList(
            [
                norm_f(
                    torch.nn.Conv2d(
                        in_ch,
                        out_ch,
                        (kernel_size, 1),
                        (stride, 1),
                        padding=(get_padding(kernel_size, 1), 0),
                    )
                )
                for in_ch, out_ch in zip(in_channels, out_channels)
            ]
        )

        self.conv_post = norm_f(torch.nn.Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))
        self.lrelu = torch.nn.LeakyReLU(LRELU_SLOPE)

    def forward(self, x):
        fmap = []
        b, c, t = x.shape

        if t % self.period != 0:
            n_pad = self.period - (t % self.period)
            x = torch.nn.functional.pad(x, (0, n_pad), "reflect")

        x = x.view(b, c, -1, self.period)

        for conv in self.convs:
            x = self.lrelu(conv(x))
            fmap.append(x)

        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)
        return x, fmap
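Annotation (not part of the deleted file): DiscriminatorP above inspects the waveform at one fixed period by reflect-padding the time axis to a multiple of that period and folding it into a 2-D grid, so each Conv2d column holds samples exactly one period apart; MultiPeriodDiscriminatorV2 simply runs one such discriminator per period in [2, 3, 5, 7, 11, 17, 23, 37] alongside the scale discriminator DiscriminatorS. A minimal standalone sketch of the fold:

import torch

period = 3
x = torch.randn(1, 1, 100)  # (batch, channels, samples)
b, c, t = x.shape
if t % period != 0:
    # reflect-pad so the sample count becomes a multiple of the period
    x = torch.nn.functional.pad(x, (0, period - (t % period)), "reflect")
x = x.view(b, c, -1, period)
print(x.shape)  # torch.Size([1, 1, 34, 3])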
class EpochRecorder:
    def __init__(self):
        self.last_time = ttime()

    def record(self):
        now_time = ttime()
        elapsed_time = now_time - self.last_time
        self.last_time = now_time
        elapsed_time = round(elapsed_time, 1)
        elapsed_time_str = str(datetime.timedelta(seconds=int(elapsed_time)))
        current_time = datetime.datetime.now().strftime("%H:%M:%S")
        return translations["time_or_speed_training"].format(current_time=current_time, elapsed_time_str=elapsed_time_str)


def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C


def spectral_normalize_torch(magnitudes):
    return dynamic_range_compression_torch(magnitudes)


def spectral_de_normalize_torch(magnitudes):
    return dynamic_range_decompression_torch(magnitudes)


mel_basis = {}
hann_window = {}
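Annotation: dynamic_range_compression_torch is a clamped natural log, and dynamic_range_decompression_torch is its exact inverse above the clamp; with the default clip_val=1e-5 a zero magnitude maps to log(1e-5) ≈ -11.51 instead of -inf. A quick check:

import torch

x = torch.tensor([0.0, 0.5, 2.0])
compressed = torch.log(torch.clamp(x, min=1e-5))  # C=1, as in the defaults above
print(compressed)                                 # tensor([-11.5129, -0.6931, 0.6931])
print(torch.exp(compressed))                      # tensor([1.0000e-05, 0.5000, 2.0000])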
def spectrogram_torch(y, n_fft, hop_size, win_size, center=False):
    global hann_window
    dtype_device = str(y.dtype) + "_" + str(y.device)
    wnsize_dtype_device = str(win_size) + "_" + dtype_device
    if wnsize_dtype_device not in hann_window: hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(dtype=y.dtype, device=y.device)

    y = torch.nn.functional.pad(y.unsqueeze(1), (int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)), mode="reflect")
    y = y.squeeze(1)

    spec = torch.stft(y, n_fft, hop_length=hop_size, win_length=win_size, window=hann_window[wnsize_dtype_device], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)

    spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + 1e-6)
    return spec


def spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax):
    global mel_basis
    dtype_device = str(spec.dtype) + "_" + str(spec.device)
    fmax_dtype_device = str(fmax) + "_" + dtype_device

    if fmax_dtype_device not in mel_basis:
        mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax)
        mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(dtype=spec.dtype, device=spec.device)

    melspec = torch.matmul(mel_basis[fmax_dtype_device], spec)
    melspec = spectral_normalize_torch(melspec)
    return melspec


def mel_spectrogram_torch(y, n_fft, num_mels, sample_rate, hop_size, win_size, fmin, fmax, center=False):
    spec = spectrogram_torch(y, n_fft, hop_size, win_size, center)
    melspec = spec_to_mel_torch(spec, n_fft, num_mels, sample_rate, fmin, fmax)
    return melspec
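Annotation: the module-level mel_basis and hann_window dicts memoize the mel filterbank and the Hann window, keyed by dtype and device, so repeated calls never rebuild or re-transfer these constants across CPU/GPU or fp16/fp32. For orientation, a shape sketch of the magnitude STFT these functions compute; the n_fft/hop values here are illustrative assumptions, not read from this diff:

import torch

y = torch.randn(1, 40000)  # (batch, samples), hypothetical 1 s of 40 kHz audio
n_fft, hop = 2048, 400
spec = torch.stft(y, n_fft, hop_length=hop, win_length=n_fft,
                  window=torch.hann_window(n_fft), center=False,
                  return_complex=True).abs()
print(spec.shape)  # torch.Size([1, 1025, 95]) -> (batch, n_fft // 2 + 1, frames)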
def replace_keys_in_dict(d, old_key_part, new_key_part):
    updated_dict = OrderedDict() if isinstance(d, OrderedDict) else {}

    for key, value in d.items():
        new_key = (key.replace(old_key_part, new_key_part) if isinstance(key, str) else key)
        updated_dict[new_key] = (replace_keys_in_dict(value, old_key_part, new_key_part) if isinstance(value, dict) else value)

    return updated_dict
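Annotation: replace_keys_in_dict recursively rewrites substrings inside the string keys of a (possibly nested) dict; extract_model below applies it twice to map the modern torch parametrization key names back to the legacy .weight_g/.weight_v layout. A toy run against the definition above:

from collections import OrderedDict

d = OrderedDict({"conv.parametrizations.weight.original0": 1, "sub": {"conv.parametrizations.weight.original1": 2}})
out = replace_keys_in_dict(replace_keys_in_dict(d, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g")
print(list(out.keys()))         # ['conv.weight_g', 'sub']
print(list(out["sub"].keys()))  # ['conv.weight_v']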
def extract_model(ckpt, sr, pitch_guidance, name, model_dir, epoch, step, version, hps, model_author):
    try:
        logger.info(translations["savemodel"].format(model_dir=model_dir, epoch=epoch, step=step))

        model_dir_path = os.path.join("assets", "weights")

        if "best_epoch" in model_dir: pth_file = f"{name}_{epoch}e_{step}s_best_epoch.pth"
        else: pth_file = f"{name}_{epoch}e_{step}s.pth"

        pth_file_old_version_path = os.path.join(model_dir_path, f"{pth_file}_old_version.pth")

        opt = OrderedDict(weight={key: value.half() for key, value in ckpt.items() if "enc_q" not in key})

        opt["config"] = [
            hps.data.filter_length // 2 + 1,
            32,
            hps.model.inter_channels,
            hps.model.hidden_channels,
            hps.model.filter_channels,
            hps.model.n_heads,
            hps.model.n_layers,
            hps.model.kernel_size,
            hps.model.p_dropout,
            hps.model.resblock,
            hps.model.resblock_kernel_sizes,
            hps.model.resblock_dilation_sizes,
            hps.model.upsample_rates,
            hps.model.upsample_initial_channel,
            hps.model.upsample_kernel_sizes,
            hps.model.spk_embed_dim,
            hps.model.gin_channels,
            hps.data.sample_rate,
        ]

        opt["epoch"] = f"{epoch}epoch"
        opt["step"] = step
        opt["sr"] = sr
        opt["f0"] = int(pitch_guidance)
        opt["version"] = version
        opt["creation_date"] = datetime.datetime.now().isoformat()

        hash_input = f"{str(ckpt)} {epoch} {step} {datetime.datetime.now().isoformat()}"
        model_hash = hashlib.sha256(hash_input.encode()).hexdigest()
        opt["model_hash"] = model_hash
        opt["model_name"] = name
        opt["author"] = model_author

        torch.save(opt, os.path.join(model_dir_path, pth_file))

        model = torch.load(model_dir, map_location=torch.device("cpu"))
        torch.save(replace_keys_in_dict(replace_keys_in_dict(model, ".parametrizations.weight.original1", ".weight_v"), ".parametrizations.weight.original0", ".weight_g"), pth_file_old_version_path)

        os.remove(model_dir)
        os.rename(pth_file_old_version_path, model_dir)
    except Exception as e:
        logger.error(f"{translations['extract_model_error']}: {e}")
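Annotation: extract_model exports an inference-only .pth: keys containing enc_q (the posterior encoder, needed only during training) are dropped, the remaining tensors are halved to fp16, and the file is stamped with a SHA-256 of the checkpoint plus a timestamp. The filtering step in isolation:

import torch
from collections import OrderedDict

ckpt = {"dec.conv.weight": torch.randn(4), "enc_q.proj.weight": torch.randn(4)}
weight = OrderedDict((k, v.half()) for k, v in ckpt.items() if "enc_q" not in k)
print(list(weight.keys()))  # ['dec.conv.weight'] -- enc_q is training-only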
def run(rank, n_gpus, experiment_dir, pretrainG, pretrainD, pitch_guidance, custom_total_epoch, custom_save_every_weights, config, device, model_author):
    global global_step

    if rank == 0:
        writer = SummaryWriter(log_dir=experiment_dir)
        writer_eval = SummaryWriter(log_dir=os.path.join(experiment_dir, "eval"))

    dist.init_process_group(backend="gloo", init_method="env://", world_size=n_gpus, rank=rank)
    torch.manual_seed(config.train.seed)

    if torch.cuda.is_available(): torch.cuda.set_device(rank)

    train_dataset = TextAudioLoaderMultiNSFsid(config.data)

    train_sampler = DistributedBucketSampler(train_dataset, batch_size * n_gpus, [100, 200, 300, 400, 500, 600, 700, 800, 900], num_replicas=n_gpus, rank=rank, shuffle=True)

    collate_fn = TextAudioCollateMultiNSFsid()

    train_loader = DataLoader(train_dataset, num_workers=4, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler, persistent_workers=True, prefetch_factor=8)

    net_g = Synthesizer(config.data.filter_length // 2 + 1, config.train.segment_size // config.data.hop_length, **config.model, use_f0=pitch_guidance == True, is_half=config.train.fp16_run and device.type == "cuda", sr=sample_rate).to(device)

    if torch.cuda.is_available(): net_g = net_g.cuda(rank)

    if version == "v1": net_d = MultiPeriodDiscriminator(config.model.use_spectral_norm)
    else: net_d = MultiPeriodDiscriminatorV2(config.model.use_spectral_norm)

    if torch.cuda.is_available(): net_d = net_d.cuda(rank)

    optim_g = torch.optim.AdamW(net_g.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)
    optim_d = torch.optim.AdamW(net_d.parameters(), config.train.learning_rate, betas=config.train.betas, eps=config.train.eps)

    if torch.cuda.is_available():
        net_g = DDP(net_g, device_ids=[rank])
        net_d = DDP(net_d, device_ids=[rank])
    else:
        net_g = DDP(net_g)
        net_d = DDP(net_d)

    try:
        logger.info(translations["start_training"])

        _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "D_*.pth"), net_d, optim_d)
        _, _, _, epoch_str = load_checkpoint(latest_checkpoint_path(experiment_dir, "G_*.pth"), net_g, optim_g)

        epoch_str += 1
        global_step = (epoch_str - 1) * len(train_loader)
    except:
        epoch_str = 1
        global_step = 0

    if pretrainG != "":
        if rank == 0: logger.info(translations["import_pretrain"].format(dg="G", pretrain=pretrainG))

        if hasattr(net_g, "module"): net_g.module.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
        else: net_g.load_state_dict(torch.load(pretrainG, map_location="cpu")["model"])
    else: logger.warning(translations["not_using_pretrain"].format(dg="G"))

    if pretrainD != "":
        if rank == 0: logger.info(translations["import_pretrain"].format(dg="D", pretrain=pretrainD))

        if hasattr(net_d, "module"): net_d.module.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
        else: net_d.load_state_dict(torch.load(pretrainD, map_location="cpu")["model"])
    else: logger.warning(translations["not_using_pretrain"].format(dg="D"))

    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=config.train.lr_decay, last_epoch=epoch_str - 2)

    optim_d.step()
    optim_g.step()

    scaler = GradScaler(enabled=config.train.fp16_run)

    cache = []

    for info in train_loader:
        phone, phone_lengths, pitch, pitchf, _, _, _, _, sid = info
        reference = (
            phone.to(device),
            phone_lengths.to(device),
            pitch.to(device) if pitch_guidance else None,
            pitchf.to(device) if pitch_guidance else None,
            sid.to(device),
        )
        break

    for epoch in range(epoch_str, total_epoch + 1):
        if rank == 0: train_and_evaluate(rank, epoch, config, [net_g, net_d], [optim_g, optim_d], scaler, [train_loader, None], [writer, writer_eval], cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author)
        else: train_and_evaluate(rank, epoch, config, [net_g, net_d], [optim_g, optim_d], scaler, [train_loader, None], None, cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author)

        scheduler_g.step()
        scheduler_d.step()
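Annotation: run() builds its schedulers with last_epoch=epoch_str - 2 so that, on resume, the exponentially decayed learning rate continues from where the previous session stopped. The decay itself is just lr * gamma per epoch; a small self-contained demonstration with hypothetical numbers:

import torch

opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
opt.step()  # step the optimizer once before the scheduler, as the code above does
sched = torch.optim.lr_scheduler.ExponentialLR(opt, gamma=0.999)
for _ in range(5):
    sched.step()
print(opt.param_groups[0]["lr"])  # 1e-4 * 0.999**5, roughly 9.95e-05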
def train_and_evaluate(rank, epoch, hps, nets, optims, scaler, loaders, writers, cache, custom_save_every_weights, custom_total_epoch, device, reference, model_author):
    global global_step, lowest_value, loss_disc, consecutive_increases_gen, consecutive_increases_disc

    if epoch == 1:
        lowest_value = {"step": 0, "value": float("inf"), "epoch": 0}
        last_loss_gen_all = 0.0
        consecutive_increases_gen = 0
        consecutive_increases_disc = 0

    net_g, net_d = nets
    optim_g, optim_d = optims
    train_loader = loaders[0] if loaders is not None else None

    if writers is not None: writer = writers[0]

    train_loader.batch_sampler.set_epoch(epoch)

    net_g.train()
    net_d.train()

    if device.type == "cuda" and cache_data_in_gpu:
        data_iterator = cache

        if cache == []:
            for batch_idx, info in enumerate(train_loader):
                (
                    phone,
                    phone_lengths,
                    pitch,
                    pitchf,
                    spec,
                    spec_lengths,
                    wave,
                    wave_lengths,
                    sid,
                ) = info
                cache.append(
                    (batch_idx, (
                        phone.cuda(rank, non_blocking=True),
                        phone_lengths.cuda(rank, non_blocking=True),
                        (pitch.cuda(rank, non_blocking=True) if pitch_guidance else None),
                        (pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None),
                        spec.cuda(rank, non_blocking=True),
                        spec_lengths.cuda(rank, non_blocking=True),
                        wave.cuda(rank, non_blocking=True),
                        wave_lengths.cuda(rank, non_blocking=True),
                        sid.cuda(rank, non_blocking=True),
                    ),
                ))
        else: shuffle(cache)
    else: data_iterator = enumerate(train_loader)

    epoch_recorder = EpochRecorder()
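Annotation: the cache_data_in_gpu branch trades GPU memory for dataloader latency -- on the first epoch every batch is copied to the device and kept, and later epochs only reshuffle the cached list, so host-to-device transfers happen once per run. The pattern, stripped of the tensor plumbing:

import random

cache = []

def batches(loader, use_cache):
    if use_cache:
        if not cache:
            cache.extend(enumerate(loader))  # first pass: materialize everything
        else:
            random.shuffle(cache)            # later passes: reuse, just reshuffle
        return cache
    return enumerate(loader)

for idx, batch in batches([10, 20, 30], use_cache=True):
    pass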
    with tqdm(total=len(train_loader), leave=False) as pbar:
        for batch_idx, info in data_iterator:
            (
                phone,
                phone_lengths,
                pitch,
                pitchf,
                spec,
                spec_lengths,
                wave,
                wave_lengths,
                sid,
            ) = info
            if device.type == "cuda" and not cache_data_in_gpu:
                phone = phone.cuda(rank, non_blocking=True)
                phone_lengths = phone_lengths.cuda(rank, non_blocking=True)
                pitch = pitch.cuda(rank, non_blocking=True) if pitch_guidance else None
                pitchf = (pitchf.cuda(rank, non_blocking=True) if pitch_guidance else None)
                sid = sid.cuda(rank, non_blocking=True)
                spec = spec.cuda(rank, non_blocking=True)
                spec_lengths = spec_lengths.cuda(rank, non_blocking=True)
                wave = wave.cuda(rank, non_blocking=True)
                wave_lengths = wave_lengths.cuda(rank, non_blocking=True)
            else:
                phone = phone.to(device)
                phone_lengths = phone_lengths.to(device)
                pitch = pitch.to(device) if pitch_guidance else None
                pitchf = pitchf.to(device) if pitch_guidance else None
                sid = sid.to(device)
                spec = spec.to(device)
                spec_lengths = spec_lengths.to(device)
                wave = wave.to(device)
                wave_lengths = wave_lengths.to(device)

            use_amp = config.train.fp16_run and device.type == "cuda"

            with autocast(enabled=use_amp):
                (
                    y_hat,
                    ids_slice,
                    x_mask,
                    z_mask,
                    (z, z_p, m_p, logs_p, m_q, logs_q),
                ) = net_g(phone, phone_lengths, pitch, pitchf, spec, spec_lengths, sid)
                mel = spec_to_mel_torch(
                    spec,
                    config.data.filter_length,
                    config.data.n_mel_channels,
                    config.data.sample_rate,
                    config.data.mel_fmin,
                    config.data.mel_fmax,
                )
                y_mel = slice_segments(mel, ids_slice, config.train.segment_size // config.data.hop_length, dim=3)
                with autocast(enabled=False):
                    y_hat_mel = mel_spectrogram_torch(
                        y_hat.float().squeeze(1),
                        config.data.filter_length,
                        config.data.n_mel_channels,
                        config.data.sample_rate,
                        config.data.hop_length,
                        config.data.win_length,
                        config.data.mel_fmin,
                        config.data.mel_fmax,
                    )
                if use_amp: y_hat_mel = y_hat_mel.half()

                wave = slice_segments(wave, ids_slice * config.data.hop_length, config.train.segment_size, dim=3)
                y_d_hat_r, y_d_hat_g, _, _ = net_d(wave, y_hat.detach())

                with autocast(enabled=False):
                    loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g)

            optim_d.zero_grad()
            scaler.scale(loss_disc).backward()
            scaler.unscale_(optim_d)
            grad_norm_d = clip_grad_value(net_d.parameters(), None)
            scaler.step(optim_d)

            with autocast(enabled=use_amp):
                y_d_hat_r, y_d_hat_g, fmap_r, fmap_g = net_d(wave, y_hat)
                with autocast(enabled=False):
                    loss_mel = F.l1_loss(y_mel, y_hat_mel) * config.train.c_mel
                    loss_kl = (kl_loss(z_p, logs_q, m_p, logs_p, z_mask) * config.train.c_kl)
                    loss_fm = feature_loss(fmap_r, fmap_g)
                    loss_gen, losses_gen = generator_loss(y_d_hat_g)
                    loss_gen_all = loss_gen + loss_fm + loss_mel + loss_kl

                    if loss_gen_all < lowest_value["value"]:
                        lowest_value["value"] = loss_gen_all
                        lowest_value["step"] = global_step
                        lowest_value["epoch"] = epoch

                        if epoch > lowest_value["epoch"]: logger.warning(translations["training_warning"])

            optim_g.zero_grad()
            scaler.scale(loss_gen_all).backward()
            scaler.unscale_(optim_g)
            grad_norm_g = clip_grad_value(net_g.parameters(), None)
            scaler.step(optim_g)
            scaler.update()

            if rank == 0:
                if global_step % config.train.log_interval == 0:
                    lr = optim_g.param_groups[0]["lr"]

                    if loss_mel > 75: loss_mel = 75
                    if loss_kl > 9: loss_kl = 9

                    scalar_dict = {
                        "loss/g/total": loss_gen_all,
                        "loss/d/total": loss_disc,
                        "learning_rate": lr,
                        "grad_norm_d": grad_norm_d,
                        "grad_norm_g": grad_norm_g,
                    }
                    scalar_dict.update(
                        {
                            "loss/g/fm": loss_fm,
                            "loss/g/mel": loss_mel,
                            "loss/g/kl": loss_kl,
                        }
                    )
                    scalar_dict.update(
                        {f"loss/g/{i}": v for i, v in enumerate(losses_gen)}
                    )
                    scalar_dict.update(
                        {f"loss/d_r/{i}": v for i, v in enumerate(losses_disc_r)}
                    )
                    scalar_dict.update(
                        {f"loss/d_g/{i}": v for i, v in enumerate(losses_disc_g)}
                    )
                    image_dict = {
                        "slice/mel_org": plot_spectrogram_to_numpy(
                            y_mel[0].data.cpu().numpy()
                        ),
                        "slice/mel_gen": plot_spectrogram_to_numpy(
                            y_hat_mel[0].data.cpu().numpy()
                        ),
                        "all/mel": plot_spectrogram_to_numpy(mel[0].data.cpu().numpy()),
                    }

                    with torch.no_grad():
                        if hasattr(net_g, "module"): o, *_ = net_g.module.infer(*reference)
                        else: o, *_ = net_g.infer(*reference)

                    audio_dict = {f"gen/audio_{global_step:07d}": o[0, :, :]}

                    summarize(
                        writer=writer,
                        global_step=global_step,
                        images=image_dict,
                        scalars=scalar_dict,
                        audios=audio_dict,
                        audio_sample_rate=config.data.sample_rate,
                    )

            global_step += 1
            pbar.update(1)

    def check_overtraining(smoothed_loss_history, threshold, epsilon=0.004):
        if len(smoothed_loss_history) < threshold + 1: return False

        for i in range(-threshold, -1):
            if smoothed_loss_history[i + 1] > smoothed_loss_history[i]: return True
            if abs(smoothed_loss_history[i + 1] - smoothed_loss_history[i]) >= epsilon: return False

        return True

    def update_exponential_moving_average(smoothed_loss_history, new_value, smoothing=0.987):
        smoothed_value = new_value if not smoothed_loss_history else (smoothing * smoothed_loss_history[-1] + (1 - smoothing) * new_value)
        smoothed_loss_history.append(smoothed_value)
        return smoothed_value

    def save_to_json(file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history):
        data = {
            "loss_disc_history": loss_disc_history,
            "smoothed_loss_disc_history": smoothed_loss_disc_history,
            "loss_gen_history": loss_gen_history,
            "smoothed_loss_gen_history": smoothed_loss_gen_history,
        }

        with open(file_path, "w") as f:
            json.dump(data, f)
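Annotation: with smoothing=0.987, update_exponential_moving_average lets each new loss contribute only 1.3% to the running average, and check_overtraining flags a run only when that smoothed curve rises (or stalls within epsilon) across the whole threshold window. Concretely:

history = []
for loss in (2.0, 1.8, 1.9):
    smoothed = loss if not history else 0.987 * history[-1] + 0.013 * loss
    history.append(smoothed)
print(history)  # [2.0, 1.9974, 1.9961...] -- a noisy uptick barely moves the average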
    model_add = []
    model_del = []
    done = False

    if rank == 0:
        if epoch % save_every_epoch == False:
            checkpoint_suffix = f"{2333333 if save_only_latest else global_step}.pth"

            save_checkpoint(net_g, optim_g, config.train.learning_rate, epoch, os.path.join(experiment_dir, "G_" + checkpoint_suffix))
            save_checkpoint(net_d, optim_d, config.train.learning_rate, epoch, os.path.join(experiment_dir, "D_" + checkpoint_suffix))

            if custom_save_every_weights: model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))

        if overtraining_detector and epoch > 1:
            current_loss_disc = float(loss_disc)
            loss_disc_history.append(current_loss_disc)

            smoothed_value_disc = update_exponential_moving_average(smoothed_loss_disc_history, current_loss_disc)
            is_overtraining_disc = check_overtraining(smoothed_loss_disc_history, overtraining_threshold * 2)

            if is_overtraining_disc: consecutive_increases_disc += 1
            else: consecutive_increases_disc = 0

            current_loss_gen = float(lowest_value["value"])
            loss_gen_history.append(current_loss_gen)

            smoothed_value_gen = update_exponential_moving_average(smoothed_loss_gen_history, current_loss_gen)
            is_overtraining_gen = check_overtraining(smoothed_loss_gen_history, overtraining_threshold, 0.01)

            if is_overtraining_gen: consecutive_increases_gen += 1
            else: consecutive_increases_gen = 0

            if epoch % save_every_epoch == 0: save_to_json(training_file_path, loss_disc_history, smoothed_loss_disc_history, loss_gen_history, smoothed_loss_gen_history)

            if (is_overtraining_gen and consecutive_increases_gen == overtraining_threshold or is_overtraining_disc and consecutive_increases_disc == (overtraining_threshold * 2)):
                logger.info(translations["overtraining_find"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
                done = True
            else:
                logger.info(translations["best_epoch"].format(epoch=epoch, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))

                old_model_files = glob.glob(os.path.join("assets", "weights", f"{model_name}_*e_*s_best_epoch.pth"))

                for file in old_model_files:
                    model_del.append(file)

                model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s_best_epoch.pth"))

        if epoch >= custom_total_epoch:
            lowest_value_rounded = float(lowest_value["value"])
            lowest_value_rounded = round(lowest_value_rounded, 3)

            logger.info(translations["success_training"].format(epoch=epoch, global_step=global_step, loss_gen_all=round(loss_gen_all.item(), 3)))
            logger.info(translations["training_info"].format(lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))

            pid_file_path = os.path.join(experiment_dir, "config.json")

            with open(pid_file_path, "r") as pid_file:
                pid_data = json.load(pid_file)
            with open(pid_file_path, "w") as pid_file:
                pid_data.pop("process_pids", None)
                json.dump(pid_data, pid_file, indent=4)

            model_add.append(os.path.join("assets", "weights", f"{model_name}_{epoch}e_{global_step}s.pth"))
            done = True

        if model_add:
            ckpt = (net_g.module.state_dict() if hasattr(net_g, "module") else net_g.state_dict())

            for m in model_add:
                if not os.path.exists(m): extract_model(ckpt=ckpt, sr=sample_rate, pitch_guidance=pitch_guidance == True, name=model_name, model_dir=m, epoch=epoch, step=global_step, version=version, hps=hps, model_author=model_author)

        for m in model_del:
            os.remove(m)

        lowest_value_rounded = float(lowest_value["value"])
        lowest_value_rounded = round(lowest_value_rounded, 3)

        if epoch > 1 and overtraining_detector:
            remaining_epochs_gen = overtraining_threshold - consecutive_increases_gen
            remaining_epochs_disc = (overtraining_threshold * 2) - consecutive_increases_disc

            logger.info(translations["model_training_info"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step'], remaining_epochs_gen=remaining_epochs_gen, remaining_epochs_disc=remaining_epochs_disc, smoothed_value_gen=f"{smoothed_value_gen:.3f}", smoothed_value_disc=f"{smoothed_value_disc:.3f}"))
        elif epoch > 1 and overtraining_detector == False: logger.info(translations["model_training_info_2"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record(), lowest_value_rounded=lowest_value_rounded, lowest_value_epoch=lowest_value['epoch'], lowest_value_step=lowest_value['step']))
        else: logger.info(translations["model_training_info_3"].format(model_name=model_name, epoch=epoch, global_step=global_step, epoch_recorder=epoch_recorder.record()))

        last_loss_gen_all = loss_gen_all

    if done: os._exit(2333333)


if __name__ == "__main__":
    torch.multiprocessing.set_start_method("spawn")

    try:
        main()
    except Exception as e:
        logger.error(f"{translations['training_error']} {e}")
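Annotation: two conventions in the save path above are easy to miss. With save_only_latest the checkpoint suffix is pinned to the constant 2333333, so the G_/D_ files are overwritten in place instead of accumulating, and a finished or overtrained run leaves via os._exit(2333333) rather than returning normally. Exported weight files follow a fixed naming pattern:

model_name, epoch, step = "MyVoice", 300, 46800  # hypothetical values
print(f"{model_name}_{epoch}e_{step}s.pth")             # periodic / final export
print(f"{model_name}_{epoch}e_{step}s_best_epoch.pth")  # rotated best-epoch export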
main/library/algorithm/commons.py
DELETED
@@ -1,100 +0,0 @@
import math
import torch

from typing import List, Optional

def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1: m.weight.data.normal_(mean, std)

def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)

def convert_pad_shape(pad_shape):
    l = pad_shape[::-1]
    pad_shape = [item for sublist in l for item in sublist]
    return pad_shape

def kl_divergence(m_p, logs_p, m_q, logs_q):
    kl = (logs_q - logs_p) - 0.5
    kl += (0.5 * (torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)) * torch.exp(-2.0 * logs_q))
    return kl

def slice_segments(x: torch.Tensor, ids_str: torch.Tensor, segment_size: int = 4, dim: int = 2):
    if dim == 2: ret = torch.zeros_like(x[:, :segment_size])
    elif dim == 3: ret = torch.zeros_like(x[:, :, :segment_size])

    for i in range(x.size(0)):
        idx_str = ids_str[i].item()
        idx_end = idx_str + segment_size

        if dim == 2: ret[i] = x[i, idx_str:idx_end]
        else: ret[i] = x[i, :, idx_str:idx_end]

    return ret

def rand_slice_segments(x, x_lengths=None, segment_size=4):
    b, d, t = x.size()

    if x_lengths is None: x_lengths = t

    ids_str_max = x_lengths - segment_size + 1
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = slice_segments(x, ids_str, segment_size, dim=3)
    return ret, ids_str

def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    log_timescale_increment = math.log(float(max_timescale) / float(min_timescale)) / (num_timescales - 1)
    inv_timescales = min_timescale * torch.exp(torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    signal = torch.nn.functional.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal


def subsequent_mask(length):
    mask = torch.tril(torch.ones(length, length)).unsqueeze(0).unsqueeze(0)
    return mask


@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts


def convert_pad_shape(pad_shape: List[List[int]]) -> List[int]:
    return torch.tensor(pad_shape).flip(0).reshape(-1).int().tolist()


def sequence_mask(length: torch.Tensor, max_length: Optional[int] = None):
    if max_length is None: max_length = length.max()

    x = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return x.unsqueeze(0) < length.unsqueeze(1)


def clip_grad_value(parameters, clip_value, norm_type=2):
    if isinstance(parameters, torch.Tensor): parameters = [parameters]

    parameters = list(filter(lambda p: p.grad is not None, parameters))
    norm_type = float(norm_type)

    if clip_value is not None: clip_value = float(clip_value)

    total_norm = 0

    for p in parameters:
        param_norm = p.grad.data.norm(norm_type)
        total_norm += param_norm.item() ** norm_type
        if clip_value is not None: p.grad.data.clamp_(min=-clip_value, max=clip_value)

    total_norm = total_norm ** (1.0 / norm_type)
    return total_norm
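Annotation: commons.py is the shared tensor math for the synthesizer -- sequence_mask builds the boolean padding masks used throughout, and slice_segments / rand_slice_segments cut the random training windows whose start indices train.py feeds back into slice_segments for the mel loss. A quick look at both, reusing the definitions above:

import torch

lengths = torch.tensor([3, 5])
print(sequence_mask(lengths))  # rows are True up to each sequence's length

x = torch.arange(20.).view(1, 2, 10)  # (batch, channels, frames)
seg, ids = rand_slice_segments(x, segment_size=4)
print(seg.shape, ids)                 # torch.Size([1, 2, 4]) plus the start frame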
main/library/algorithm/modules.py
DELETED
@@ -1,80 +0,0 @@
import os
import sys
import torch

now_dir = os.getcwd()
sys.path.append(now_dir)

from .commons import fused_add_tanh_sigmoid_multiply

class WaveNet(torch.nn.Module):
    def __init__(self, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0, p_dropout=0):
        super(WaveNet, self).__init__()
        assert kernel_size % 2 == 1

        self.hidden_channels = hidden_channels
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate

        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = torch.nn.Dropout(p_dropout)

        if gin_channels != 0:
            cond_layer = torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1)
            self.cond_layer = torch.nn.utils.parametrizations.weight_norm(cond_layer, name="weight")

        dilations = [dilation_rate**i for i in range(n_layers)]
        paddings = [(kernel_size * d - d) // 2 for d in dilations]

        for i in range(n_layers):
            in_layer = torch.nn.Conv1d(hidden_channels, 2 * hidden_channels, kernel_size, dilation=dilations[i], padding=paddings[i])
            in_layer = torch.nn.utils.parametrizations.weight_norm(in_layer, name="weight")

            self.in_layers.append(in_layer)

            res_skip_channels = (hidden_channels if i == n_layers - 1 else 2 * hidden_channels)
            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.parametrizations.weight_norm(res_skip_layer, name="weight")

            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None: g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)

            if g is not None:
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else: g_l = torch.zeros_like(x_in)

            acts = fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)

            if i < self.n_layers - 1:
                res_acts = res_skip_acts[:, : self.hidden_channels, :]
                x = (x + res_acts) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else: output = output + res_skip_acts

        return output * x_mask

    def remove_weight_norm(self):
        if self.gin_channels != 0: torch.nn.utils.remove_weight_norm(self.cond_layer)

        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)

        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
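Annotation: WaveNet's nonlinearity is the gated activation unit -- the convolution output carries twice the hidden channels, half passes through tanh (content) and half through sigmoid (gate), and the two halves are multiplied, which is exactly what fused_add_tanh_sigmoid_multiply in commons.py computes. In plain tensor ops:

import torch

hidden = 4
x_in = torch.randn(1, 2 * hidden, 10)  # conv output: doubled channels before gating
g_l = torch.zeros_like(x_in)           # conditioning term (zeros when g is None)
in_act = x_in + g_l
acts = torch.tanh(in_act[:, :hidden]) * torch.sigmoid(in_act[:, hidden:])
print(acts.shape)  # torch.Size([1, 4, 10])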
main/library/algorithm/residuals.py
DELETED
@@ -1,170 +0,0 @@
import os
import sys
import torch

from typing import Optional

from torch.nn.utils import remove_weight_norm
from torch.nn.utils.parametrizations import weight_norm

now_dir = os.getcwd()
sys.path.append(now_dir)

from .modules import WaveNet
from .commons import get_padding, init_weights


LRELU_SLOPE = 0.1

def create_conv1d_layer(channels, kernel_size, dilation):
    return weight_norm(torch.nn.Conv1d(channels, channels, kernel_size, 1, dilation=dilation, padding=get_padding(kernel_size, dilation)))

def apply_mask(tensor, mask):
    return tensor * mask if mask is not None else tensor

class ResBlockBase(torch.nn.Module):
    def __init__(self, channels, kernel_size, dilations):
        super(ResBlockBase, self).__init__()

        self.convs1 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, d) for d in dilations])
        self.convs1.apply(init_weights)

        self.convs2 = torch.nn.ModuleList([create_conv1d_layer(channels, kernel_size, 1) for _ in dilations])
        self.convs2.apply(init_weights)

    def forward(self, x, x_mask=None):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
            xt = apply_mask(xt, x_mask)
            xt = torch.nn.functional.leaky_relu(c1(xt), LRELU_SLOPE)
            xt = apply_mask(xt, x_mask)
            xt = c2(xt)
            x = xt + x
        return apply_mask(x, x_mask)

    def remove_weight_norm(self):
        for conv in self.convs1 + self.convs2:
            remove_weight_norm(conv)

class ResBlock1(ResBlockBase):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__(channels, kernel_size, dilation)

class ResBlock2(ResBlockBase):
    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__(channels, kernel_size, dilation)

class Log(torch.nn.Module):
    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
            logdet = torch.sum(-y, [1, 2])
            return y, logdet
        else:
            x = torch.exp(x) * x_mask
            return x

class Flip(torch.nn.Module):
    def forward(self, x, *args, reverse=False, **kwargs):
        x = torch.flip(x, [1])
        if not reverse:
            logdet = torch.zeros(x.size(0)).to(dtype=x.dtype, device=x.device)
            return x, logdet
        else: return x

class ElementwiseAffine(torch.nn.Module):
    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = torch.nn.Parameter(torch.zeros(channels, 1))
        self.logs = torch.nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if not reverse:
            y = self.m + torch.exp(self.logs) * x
            y = y * x_mask
            logdet = torch.sum(self.logs * x_mask, [1, 2])
            return y, logdet
        else:
            x = (x - self.m) * torch.exp(-self.logs) * x_mask
            return x


class ResidualCouplingBlock(torch.nn.Module):
    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, n_flows=4, gin_channels=0):
        super(ResidualCouplingBlock, self).__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = torch.nn.ModuleList()
        for i in range(n_flows):
            self.flows.append(ResidualCouplingLayer(channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels, mean_only=True))
            self.flows.append(Flip())

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor, g: Optional[torch.Tensor] = None, reverse = False):
        if not reverse:
            for flow in self.flows:
                x, _ = flow(x, x_mask, g=g, reverse=reverse)
        else:
            for flow in reversed(self.flows):
                x = flow.forward(x, x_mask, g=g, reverse=reverse)

        return x

    def remove_weight_norm(self):
        for i in range(self.n_flows):
            self.flows[i * 2].remove_weight_norm()

    def __prepare_scriptable__(self):
        for i in range(self.n_flows):
            for hook in self.flows[i * 2]._forward_pre_hooks.values():
                if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.flows[i * 2])
        return self


class ResidualCouplingLayer(torch.nn.Module):
    def __init__(self, channels, hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=0, gin_channels=0, mean_only=False):
        assert channels % 2 == 0, "Channels/2"
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = channels // 2
        self.mean_only = mean_only

        self.pre = torch.nn.Conv1d(self.half_channels, hidden_channels, 1)
        self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, p_dropout=p_dropout, gin_channels=gin_channels)
        self.post = torch.nn.Conv1d(hidden_channels, self.half_channels * (2 - mean_only), 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        h = self.pre(x0) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.post(h) * x_mask

        if not self.mean_only: m, logs = torch.split(stats, [self.half_channels] * 2, 1)
        else:
            m = stats
            logs = torch.zeros_like(m)

        if not reverse:
            x1 = m + x1 * torch.exp(logs) * x_mask
            x = torch.cat([x0, x1], 1)
            logdet = torch.sum(logs, [1, 2])
            return x, logdet
        else:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            x = torch.cat([x0, x1], 1)
            return x

    def remove_weight_norm(self):
        self.enc.remove_weight_norm()
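Annotation: ResidualCouplingLayer is an affine coupling step -- half the channels (x0) pass through unchanged and parameterize a shift m (and, when mean_only is False, a log-scale logs) applied to the other half, so the flow inverts exactly by reversing the arithmetic. Verified with scalars:

import math

x1, m, logs = 0.7, 0.2, 0.5
y1 = m + x1 * math.exp(logs)          # forward direction
x1_back = (y1 - m) * math.exp(-logs)  # reverse direction
print(round(x1_back, 10))             # 0.7 -- exactly invertible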
main/library/algorithm/separator.py
DELETED
@@ -1,420 +0,0 @@
import os
import sys
import time
import json
import yaml
import torch
import codecs
import hashlib
import logging
import platform
import warnings
import requests
import subprocess

import onnxruntime as ort

from tqdm import tqdm
from importlib import metadata, import_module

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
translations = Config().translations

class Separator:
    def __init__(self, log_level=logging.INFO, log_formatter=None, model_file_dir="assets/model/uvr5", output_dir=None, output_format="wav", output_bitrate=None, normalization_threshold=0.9, output_single_stem=None, invert_using_spec=False, sample_rate=44100, mdx_params={"hop_length": 1024, "segment_size": 256, "overlap": 0.25, "batch_size": 1, "enable_denoise": False}, demucs_params={"segment_size": "Default", "shifts": 2, "overlap": 0.25, "segments_enabled": True}):
        self.logger = logging.getLogger(__name__)
        self.logger.setLevel(log_level)
        self.log_level = log_level
        self.log_formatter = log_formatter
        self.log_handler = logging.StreamHandler()

        if self.log_formatter is None: self.log_formatter = logging.Formatter("%(asctime)s - %(levelname)s - %(module)s - %(message)s")

        self.log_handler.setFormatter(self.log_formatter)

        if not self.logger.hasHandlers(): self.logger.addHandler(self.log_handler)
        if log_level > logging.DEBUG: warnings.filterwarnings("ignore")

        self.logger.info(translations["separator_info"].format(output_dir=output_dir, output_format=output_format))

        self.model_file_dir = model_file_dir

        if output_dir is None:
            output_dir = os.getcwd()
            self.logger.info(translations["output_dir_is_none"])

        self.output_dir = output_dir

        os.makedirs(self.model_file_dir, exist_ok=True)
        os.makedirs(self.output_dir, exist_ok=True)

        self.output_format = output_format
        self.output_bitrate = output_bitrate

        if self.output_format is None: self.output_format = "wav"

        self.normalization_threshold = normalization_threshold

        if normalization_threshold <= 0 or normalization_threshold > 1: raise ValueError(translations[">0or=1"])

        self.output_single_stem = output_single_stem

        if output_single_stem is not None: self.logger.debug(translations["output_single"].format(output_single_stem=output_single_stem))

        self.invert_using_spec = invert_using_spec
        if self.invert_using_spec: self.logger.debug(translations["step2"])

        try:
            self.sample_rate = int(sample_rate)

            if self.sample_rate <= 0: raise ValueError(translations["other_than_zero"].format(sample_rate=self.sample_rate))
            if self.sample_rate > 12800000: raise ValueError(translations["too_high"].format(sample_rate=self.sample_rate))
        except ValueError:
            raise ValueError(translations["sr_not_valid"])

        self.arch_specific_params = {"MDX": mdx_params, "Demucs": demucs_params}
        self.torch_device = None
        self.torch_device_cpu = None
        self.torch_device_mps = None
        self.onnx_execution_provider = None
        self.model_instance = None
        self.model_is_uvr_vip = False
        self.model_friendly_name = None

        self.setup_accelerated_inferencing_device()

    def setup_accelerated_inferencing_device(self):
        system_info = self.get_system_info()
        self.check_ffmpeg_installed()
        self.log_onnxruntime_packages()
        self.setup_torch_device(system_info)

    def get_system_info(self):
        os_name = platform.system()
        os_version = platform.version()

        self.logger.info(f"{translations['os']}: {os_name} {os_version}")

        system_info = platform.uname()
        self.logger.info(translations["platform_info"].format(system_info=system_info, node=system_info.node, release=system_info.release, machine=system_info.machine, processor=system_info.processor))

        python_version = platform.python_version()
        self.logger.info(f"{translations['name_ver'].format(name='python')}: {python_version}")

        pytorch_version = torch.__version__
        self.logger.info(f"{translations['name_ver'].format(name='pytorch')}: {pytorch_version}")

        return system_info

    def check_ffmpeg_installed(self):
        try:
            ffmpeg_version_output = subprocess.check_output(["ffmpeg", "-version"], text=True)
            first_line = ffmpeg_version_output.splitlines()[0]
            self.logger.info(f"{translations['install_ffmpeg']}: {first_line}")
        except FileNotFoundError:
            self.logger.error(translations["none_ffmpeg"])
            if "PYTEST_CURRENT_TEST" not in os.environ: raise

    def log_onnxruntime_packages(self):
        onnxruntime_gpu_package = self.get_package_distribution("onnxruntime-gpu")
        onnxruntime_cpu_package = self.get_package_distribution("onnxruntime")

        if onnxruntime_gpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='GPU')}: {onnxruntime_gpu_package.version}")
        if onnxruntime_cpu_package is not None: self.logger.info(f"{translations['install_onnx'].format(pu='CPU')}: {onnxruntime_cpu_package.version}")

    def setup_torch_device(self, system_info):
        hardware_acceleration_enabled = False
        ort_providers = ort.get_available_providers()

        self.torch_device_cpu = torch.device("cpu")

        if torch.cuda.is_available():
            self.configure_cuda(ort_providers)
            hardware_acceleration_enabled = True
        elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available() and system_info.processor == "arm":
            self.configure_mps(ort_providers)
            hardware_acceleration_enabled = True

        if not hardware_acceleration_enabled:
            self.logger.info(translations["running_in_cpu"])
            self.torch_device = self.torch_device_cpu
            self.onnx_execution_provider = ["CPUExecutionProvider"]

    def configure_cuda(self, ort_providers):
        self.logger.info(translations["running_in_cuda"])
        self.torch_device = torch.device("cuda")

        if "CUDAExecutionProvider" in ort_providers:
            self.logger.info(translations["onnx_have"].format(have='CUDAExecutionProvider'))
            self.onnx_execution_provider = ["CUDAExecutionProvider"]
        else: self.logger.warning(translations["onnx_not_have"].format(have='CUDAExecutionProvider'))

    def configure_mps(self, ort_providers):
        self.logger.info("Setting the Torch device to MPS")
        self.torch_device_mps = torch.device("mps")
        self.torch_device = self.torch_device_mps

        if "CoreMLExecutionProvider" in ort_providers:
            self.logger.info(translations["onnx_have"].format(have='CoreMLExecutionProvider'))
            self.onnx_execution_provider = ["CoreMLExecutionProvider"]
        else: self.logger.warning(translations["onnx_not_have"].format(have='CoreMLExecutionProvider'))

    def get_package_distribution(self, package_name):
        try:
            return metadata.distribution(package_name)
        except metadata.PackageNotFoundError:
            self.logger.debug(translations["python_not_install"].format(package_name=package_name))
            return None

    def get_model_hash(self, model_path):
        self.logger.debug(translations["hash"].format(model_path=model_path))

        try:
            with open(model_path, "rb") as f:
                f.seek(-10000 * 1024, 2)
                return hashlib.md5(f.read()).hexdigest()
        except IOError as e:
            self.logger.error(translations["ioerror"].format(e=e))
            return hashlib.md5(open(model_path, "rb").read()).hexdigest()

    def download_file_if_not_exists(self, url, output_path):
        if os.path.isfile(output_path):
            self.logger.debug(translations["cancel_download"].format(output_path=output_path))
            return

        self.logger.debug(translations["download_model"].format(url=url, output_path=output_path))
        response = requests.get(url, stream=True, timeout=300)

        if response.status_code == 200:
            total_size_in_bytes = int(response.headers.get("content-length", 0))
            progress_bar = tqdm(total=total_size_in_bytes)

            with open(output_path, "wb") as f:
                for chunk in response.iter_content(chunk_size=8192):
                    progress_bar.update(len(chunk))
                    f.write(chunk)

            progress_bar.close()
        else: raise RuntimeError(translations["download_error"].format(url=url, status_code=response.status_code))

    def print_uvr_vip_message(self):
        if self.model_is_uvr_vip:
            self.logger.warning(translations["vip_model"].format(model_friendly_name=self.model_friendly_name))
            self.logger.warning(translations["vip_print"])

    def list_supported_model_files(self):
        download_checks_path = os.path.join(self.model_file_dir, "download_checks.json")

        model_downloads_list = json.load(open(download_checks_path, encoding="utf-8"))
        self.logger.debug(translations["load_download_json"])

        filtered_demucs_v4 = {key: value for key, value in model_downloads_list["demucs_download_list"].items() if key.startswith("Demucs v4")}

        model_files_grouped_by_type = {"MDX": {**model_downloads_list["mdx_download_list"], **model_downloads_list["mdx_download_vip_list"]}, "Demucs": filtered_demucs_v4}
        return model_files_grouped_by_type

    def download_model_files(self, model_filename):
        model_path = os.path.join(self.model_file_dir, f"{model_filename}")

        supported_model_files_grouped = self.list_supported_model_files()
        public_model_repo_url_prefix = codecs.decode("uggcf://tvguho.pbz/GEiyie/zbqry_ercb/eryrnfrf/qbjaybnq/nyy_choyvp_hie_zbqryf", "rot13")
        vip_model_repo_url_prefix = codecs.decode("uggcf://tvguho.pbz/Nawbx0109/nv_zntvp/eryrnfrf/qbjaybnq/i5", "rot13")

        audio_separator_models_repo_url_prefix = codecs.decode("uggcf://tvguho.pbz/abznqxnenbxr/clguba-nhqvb-frcnengbe/eryrnfrf/qbjaybnq/zbqry-pbasvtf", "rot13")

        yaml_config_filename = None

        self.logger.debug(translations["search_model"].format(model_filename=model_filename))

        for model_type, model_list in supported_model_files_grouped.items():
            for model_friendly_name, model_download_list in model_list.items():
|
235 |
-
self.model_is_uvr_vip = "VIP" in model_friendly_name
|
236 |
-
model_repo_url_prefix = vip_model_repo_url_prefix if self.model_is_uvr_vip else public_model_repo_url_prefix
|
237 |
-
|
238 |
-
if isinstance(model_download_list, str) and model_download_list == model_filename:
|
239 |
-
self.logger.debug(translations["single_model"].format(model_friendly_name=model_friendly_name))
|
240 |
-
self.model_friendly_name = model_friendly_name
|
241 |
-
|
242 |
-
try:
|
243 |
-
self.download_file_if_not_exists(f"{model_repo_url_prefix}/{model_filename}", model_path)
|
244 |
-
except RuntimeError:
|
245 |
-
self.logger.debug(translations["not_found_model"])
|
246 |
-
self.download_file_if_not_exists(f"{audio_separator_models_repo_url_prefix}/{model_filename}", model_path)
|
247 |
-
|
248 |
-
self.print_uvr_vip_message()
|
249 |
-
|
250 |
-
self.logger.debug(translations["single_model_path"].format(model_path=model_path))
|
251 |
-
|
252 |
-
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
|
253 |
-
elif isinstance(model_download_list, dict):
|
254 |
-
this_model_matches_input_filename = False
|
255 |
-
|
256 |
-
for file_name, file_url in model_download_list.items():
|
257 |
-
if file_name == model_filename or file_url == model_filename:
|
258 |
-
self.logger.debug(translations["find_model"].format(model_filename=model_filename, model_friendly_name=model_friendly_name))
|
259 |
-
this_model_matches_input_filename = True
|
260 |
-
|
261 |
-
if this_model_matches_input_filename:
|
262 |
-
self.logger.debug(translations["find_models"].format(model_friendly_name=model_friendly_name))
|
263 |
-
self.model_friendly_name = model_friendly_name
|
264 |
-
self.print_uvr_vip_message()
|
265 |
-
|
266 |
-
for config_key, config_value in model_download_list.items():
|
267 |
-
self.logger.debug(f"{translations['find_path']}: {config_key} -> {config_value}")
|
268 |
-
|
269 |
-
if config_value.startswith("http"): self.download_file_if_not_exists(config_value, os.path.join(self.model_file_dir, config_key))
|
270 |
-
elif config_key.endswith(".ckpt"):
|
271 |
-
try:
|
272 |
-
download_url = f"{model_repo_url_prefix}/{config_key}"
|
273 |
-
self.download_file_if_not_exists(download_url, os.path.join(self.model_file_dir, config_key))
|
274 |
-
except RuntimeError:
|
275 |
-
self.logger.debug(translations["not_found_model_warehouse"])
|
276 |
-
download_url = f"{audio_separator_models_repo_url_prefix}/{config_key}"
|
277 |
-
self.download_file_if_not_exists(download_url, os.path.join(self.model_file_dir, config_key))
|
278 |
-
|
279 |
-
if model_filename.endswith(".yaml"):
|
280 |
-
self.logger.warning(translations["yaml_warning"].format(model_filename=model_filename))
|
281 |
-
self.logger.warning(translations["yaml_warning_2"].format(config_key=config_key))
|
282 |
-
self.logger.warning(translations["yaml_warning_3"])
|
283 |
-
|
284 |
-
model_filename = config_key
|
285 |
-
model_path = os.path.join(self.model_file_dir, f"{model_filename}")
|
286 |
-
|
287 |
-
yaml_config_filename = config_value
|
288 |
-
yaml_config_filepath = os.path.join(self.model_file_dir, yaml_config_filename)
|
289 |
-
|
290 |
-
try:
|
291 |
-
url = codecs.decode("uggcf://enj.tvguhohfrepbagrag.pbz/GEiyie/nccyvpngvba_qngn/znva/zqk_zbqry_qngn/zqk_p_pbasvtf", "rot13")
|
292 |
-
yaml_config_url = f"{url}/{yaml_config_filename}"
|
293 |
-
self.download_file_if_not_exists(f"{yaml_config_url}", yaml_config_filepath)
|
294 |
-
except RuntimeError:
|
295 |
-
self.logger.debug(translations["yaml_debug"])
|
296 |
-
yaml_config_url = f"{audio_separator_models_repo_url_prefix}/{yaml_config_filename}"
|
297 |
-
self.download_file_if_not_exists(f"{yaml_config_url}", yaml_config_filepath)
|
298 |
-
else:
|
299 |
-
download_url = f"{model_repo_url_prefix}/{config_value}"
|
300 |
-
self.download_file_if_not_exists(download_url, os.path.join(self.model_file_dir, config_value))
|
301 |
-
|
302 |
-
self.logger.debug(translations["download_model_friendly"].format(model_friendly_name=model_friendly_name, model_path=model_path))
|
303 |
-
|
304 |
-
return model_filename, model_type, model_friendly_name, model_path, yaml_config_filename
|
305 |
-
|
306 |
-
raise ValueError(translations["not_found_model_2"].format(model_filename=model_filename))
|
307 |
-
|
308 |
-
def load_model_data_from_yaml(self, yaml_config_filename):
|
309 |
-
model_data_yaml_filepath = os.path.join(self.model_file_dir, yaml_config_filename) if not os.path.exists(yaml_config_filename) else yaml_config_filename
|
310 |
-
|
311 |
-
self.logger.debug(translations["load_yaml"].format(model_data_yaml_filepath=model_data_yaml_filepath))
|
312 |
-
|
313 |
-
model_data = yaml.load(open(model_data_yaml_filepath, encoding="utf-8"), Loader=yaml.FullLoader)
|
314 |
-
self.logger.debug(translations["load_yaml_2"].format(model_data=model_data))
|
315 |
-
|
316 |
-
if "roformer" in model_data_yaml_filepath: model_data["is_roformer"] = True
|
317 |
-
|
318 |
-
return model_data
|
319 |
-
|
320 |
-
def load_model_data_using_hash(self, model_path):
|
321 |
-
mdx_model_data_url = codecs.decode("uggcf://enj.tvguhohfrepbagrag.pbz/GEiyie/nccyvpngvba_qngn/znva/zqk_zbqry_qngn/zbqry_qngn_arj.wfba", "rot13")
|
322 |
-
|
323 |
-
self.logger.debug(translations["hash_md5"])
|
324 |
-
model_hash = self.get_model_hash(model_path)
|
325 |
-
self.logger.debug(translations["model_hash"].format(model_path=model_path, model_hash=model_hash))
|
326 |
-
|
327 |
-
mdx_model_data_path = os.path.join(self.model_file_dir, "mdx_model_data.json")
|
328 |
-
self.logger.debug(translations["mdx_data"].format(mdx_model_data_path=mdx_model_data_path))
|
329 |
-
self.download_file_if_not_exists(mdx_model_data_url, mdx_model_data_path)
|
330 |
-
|
331 |
-
self.logger.debug(translations["load_mdx"])
|
332 |
-
mdx_model_data_object = json.load(open(mdx_model_data_path, encoding="utf-8"))
|
333 |
-
|
334 |
-
if model_hash in mdx_model_data_object: model_data = mdx_model_data_object[model_hash]
|
335 |
-
else: raise ValueError(translations["model_not_support"].format(model_hash=model_hash))
|
336 |
-
|
337 |
-
self.logger.debug(translations["uvr_json"].format(model_hash=model_hash, model_data=model_data))
|
338 |
-
|
339 |
-
return model_data
|
340 |
-
|
341 |
-
def load_model(self, model_filename):
|
342 |
-
self.logger.info(translations["loading_model"].format(model_filename=model_filename))
|
343 |
-
|
344 |
-
load_model_start_time = time.perf_counter()
|
345 |
-
|
346 |
-
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
347 |
-
model_name = model_filename.split(".")[0]
|
348 |
-
self.logger.debug(translations["download_model_friendly_2"].format(model_friendly_name=model_friendly_name, model_path=model_path))
|
349 |
-
|
350 |
-
if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
|
351 |
-
|
352 |
-
model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path)
|
353 |
-
|
354 |
-
common_params = {
|
355 |
-
"logger": self.logger,
|
356 |
-
"log_level": self.log_level,
|
357 |
-
"torch_device": self.torch_device,
|
358 |
-
"torch_device_cpu": self.torch_device_cpu,
|
359 |
-
"torch_device_mps": self.torch_device_mps,
|
360 |
-
"onnx_execution_provider": self.onnx_execution_provider,
|
361 |
-
"model_name": model_name,
|
362 |
-
"model_path": model_path,
|
363 |
-
"model_data": model_data,
|
364 |
-
"output_format": self.output_format,
|
365 |
-
"output_bitrate": self.output_bitrate,
|
366 |
-
"output_dir": self.output_dir,
|
367 |
-
"normalization_threshold": self.normalization_threshold,
|
368 |
-
"output_single_stem": self.output_single_stem,
|
369 |
-
"invert_using_spec": self.invert_using_spec,
|
370 |
-
"sample_rate": self.sample_rate,
|
371 |
-
}
|
372 |
-
|
373 |
-
separator_classes = {"MDX": "mdx_separator.MDXSeparator", "Demucs": "demucs_separator.DemucsSeparator"}
|
374 |
-
|
375 |
-
if model_type not in self.arch_specific_params or model_type not in separator_classes: raise ValueError(translations["model_type_not_support"].format(model_type=model_type))
|
376 |
-
if model_type == "Demucs" and sys.version_info < (3, 10): raise Exception(translations["demucs_not_support_python<3.10"])
|
377 |
-
|
378 |
-
self.logger.debug(f"{translations['import_module']} {model_type}: {separator_classes[model_type]}")
|
379 |
-
|
380 |
-
module_name, class_name = separator_classes[model_type].split(".")
|
381 |
-
module = import_module(f"main.library.architectures.{module_name}")
|
382 |
-
separator_class = getattr(module, class_name)
|
383 |
-
|
384 |
-
self.logger.debug(f"{translations['initialization']} {model_type}: {separator_class}")
|
385 |
-
self.model_instance = separator_class(common_config=common_params, arch_config=self.arch_specific_params[model_type])
|
386 |
-
|
387 |
-
self.logger.debug(translations["loading_model_success"])
|
388 |
-
self.logger.info(f"{translations['loading_model_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - load_model_start_time)))}")
|
389 |
-
|
390 |
-
def separate(self, audio_file_path):
|
391 |
-
self.logger.info(f"{translations['starting_separator']}: {audio_file_path}")
|
392 |
-
separate_start_time = time.perf_counter()
|
393 |
-
|
394 |
-
self.logger.debug(translations["normalization"].format(normalization_threshold=self.normalization_threshold))
|
395 |
-
|
396 |
-
output_files = self.model_instance.separate(audio_file_path)
|
397 |
-
|
398 |
-
self.model_instance.clear_gpu_cache()
|
399 |
-
|
400 |
-
self.model_instance.clear_file_specific_paths()
|
401 |
-
|
402 |
-
self.print_uvr_vip_message()
|
403 |
-
|
404 |
-
self.logger.debug(translations["separator_success_3"])
|
405 |
-
self.logger.info(f"{translations['separator_duration']}: {time.strftime('%H:%M:%S', time.gmtime(int(time.perf_counter() - separate_start_time)))}")
|
406 |
-
|
407 |
-
return output_files
|
408 |
-
|
409 |
-
def download_model_and_data(self, model_filename):
|
410 |
-
self.logger.info(translations["loading_separator_model"].format(model_filename=model_filename))
|
411 |
-
|
412 |
-
model_filename, model_type, model_friendly_name, model_path, yaml_config_filename = self.download_model_files(model_filename)
|
413 |
-
|
414 |
-
if model_path.lower().endswith(".yaml"): yaml_config_filename = model_path
|
415 |
-
|
416 |
-
model_data = self.load_model_data_from_yaml(yaml_config_filename) if yaml_config_filename is not None else self.load_model_data_using_hash(model_path)
|
417 |
-
|
418 |
-
model_data_dict_size = len(model_data)
|
419 |
-
|
420 |
-
self.logger.info(translations["downloading_model"].format(model_type=model_type, model_friendly_name=model_friendly_name, model_path=model_path, model_data_dict_size=model_data_dict_size))
|
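
For reference, a minimal usage sketch of the separator wrapper deleted above. The class name "Separator" and its constructor keywords are assumptions inferred from the attributes the methods read (model_file_dir, output_dir, output_format), not confirmed by this diff; the model filename is one of the public UVR MDX checkpoints.

# Hedged sketch only: adjust the class name and keyword arguments to the real
# signature in main/app or separator.py before relying on it.
from main.library.algorithm.separator import Separator  # class name assumed

separator = Separator(model_file_dir="assets/models/uvr5", output_dir="output", output_format="wav")
separator.load_model("UVR-MDX-NET-Voc_FT.onnx")  # resolves the filename, downloads weights + model data, imports the arch class
stems = separator.separate("song.mp3")           # runs the loaded architecture and returns the written stem paths
print(stems)
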
main/library/algorithm/synthesizers.py
DELETED
@@ -1,590 +0,0 @@
-import os
-import sys
-import math
-import torch
-
-from typing import Optional
-from torch.nn.utils import remove_weight_norm
-from torch.nn.utils.parametrizations import weight_norm
-
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-
-from .modules import WaveNet
-from .residuals import ResidualCouplingBlock, ResBlock1, ResBlock2, LRELU_SLOPE
-from .commons import init_weights, slice_segments, rand_slice_segments, sequence_mask, convert_pad_shape
-
-class Generator(torch.nn.Module):
-    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=0):
-        super(Generator, self).__init__()
-
-        self.num_kernels = len(resblock_kernel_sizes)
-        self.num_upsamples = len(upsample_rates)
-        self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
-        resblock = ResBlock1 if resblock == "1" else ResBlock2
-
-        self.ups_and_resblocks = torch.nn.ModuleList()
-
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups_and_resblocks.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), upsample_initial_channel // (2 ** (i + 1)), k, u, padding=(k - u) // 2)))
-
-            ch = upsample_initial_channel // (2 ** (i + 1))
-
-            for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
-                self.ups_and_resblocks.append(resblock(ch, k, d))
-
-        self.conv_post = torch.nn.Conv1d(ch, 1, 7, 1, padding=3, bias=False)
-        self.ups_and_resblocks.apply(init_weights)
-
-        if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
-
-    def forward(self, x: torch.Tensor, g: Optional[torch.Tensor] = None):
-        x = self.conv_pre(x)
-        if g is not None: x = x + self.cond(g)
-
-        resblock_idx = 0
-
-        for _ in range(self.num_upsamples):
-            x = torch.nn.functional.leaky_relu(x, LRELU_SLOPE)
-            x = self.ups_and_resblocks[resblock_idx](x)
-            resblock_idx += 1
-            xs = 0
-
-            for _ in range(self.num_kernels):
-                xs += self.ups_and_resblocks[resblock_idx](x)
-                resblock_idx += 1
-
-            x = xs / self.num_kernels
-
-        x = torch.nn.functional.leaky_relu(x)
-        x = self.conv_post(x)
-        x = torch.tanh(x)
-
-        return x
-
-    def __prepare_scriptable__(self):
-        for l in self.ups_and_resblocks:
-            for hook in l._forward_pre_hooks.values():
-                if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(l)
-
-        return self
-    def remove_weight_norm(self):
-        for l in self.ups_and_resblocks:
-            remove_weight_norm(l)
-
-class SineGen(torch.nn.Module):
-    def __init__(self, samp_rate, harmonic_num=0, sine_amp=0.1, noise_std=0.003, voiced_threshold=0, flag_for_pulse=False):
-        super(SineGen, self).__init__()
-        self.sine_amp = sine_amp
-        self.noise_std = noise_std
-        self.harmonic_num = harmonic_num
-        self.dim = self.harmonic_num + 1
-        self.sample_rate = samp_rate
-        self.voiced_threshold = voiced_threshold
-
-    def _f02uv(self, f0):
-        uv = torch.ones_like(f0)
-        uv = uv * (f0 > self.voiced_threshold)
-        return uv
-
-    def forward(self, f0: torch.Tensor, upp: int):
-        with torch.no_grad():
-            f0 = f0[:, None].transpose(1, 2)
-            f0_buf = torch.zeros(f0.shape[0], f0.shape[1], self.dim, device=f0.device)
-            f0_buf[:, :, 0] = f0[:, :, 0]
-            f0_buf[:, :, 1:] = (f0_buf[:, :, 0:1] * torch.arange(2, self.harmonic_num + 2, device=f0.device)[None, None, :])
-            rad_values = (f0_buf / float(self.sample_rate)) % 1
-            rand_ini = torch.rand(f0_buf.shape[0], f0_buf.shape[2], device=f0_buf.device)
-            rand_ini[:, 0] = 0
-            rad_values[:, 0, :] = rad_values[:, 0, :] + rand_ini
-            tmp_over_one = torch.cumsum(rad_values, 1)
-            tmp_over_one *= upp
-            tmp_over_one = torch.nn.functional.interpolate(tmp_over_one.transpose(2, 1), scale_factor=float(upp), mode="linear", align_corners=True).transpose(2, 1)
-            rad_values = torch.nn.functional.interpolate(rad_values.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
-            tmp_over_one %= 1
-            tmp_over_one_idx = (tmp_over_one[:, 1:, :] - tmp_over_one[:, :-1, :]) < 0
-            cumsum_shift = torch.zeros_like(rad_values)
-            cumsum_shift[:, 1:, :] = tmp_over_one_idx * -1.0
-            sine_waves = torch.sin(torch.cumsum(rad_values + cumsum_shift, dim=1) * 2 * torch.pi)
-            sine_waves = sine_waves * self.sine_amp
-            uv = self._f02uv(f0)
-            uv = torch.nn.functional.interpolate(uv.transpose(2, 1), scale_factor=float(upp), mode="nearest").transpose(2, 1)
-            noise_amp = uv * self.noise_std + (1 - uv) * self.sine_amp / 3
-            noise = noise_amp * torch.randn_like(sine_waves)
-            sine_waves = sine_waves * uv + noise
-
-        return sine_waves, uv, noise
-
-class SourceModuleHnNSF(torch.nn.Module):
-    def __init__(self, sample_rate, harmonic_num=0, sine_amp=0.1, add_noise_std=0.003, voiced_threshod=0, is_half=True):
-        super(SourceModuleHnNSF, self).__init__()
-
-        self.sine_amp = sine_amp
-        self.noise_std = add_noise_std
-        self.is_half = is_half
-
-        self.l_sin_gen = SineGen(sample_rate, harmonic_num, sine_amp, add_noise_std, voiced_threshod)
-        self.l_linear = torch.nn.Linear(harmonic_num + 1, 1)
-        self.l_tanh = torch.nn.Tanh()
-
-    def forward(self, x: torch.Tensor, upsample_factor: int = 1):
-        sine_wavs, uv, _ = self.l_sin_gen(x, upsample_factor)
-        sine_wavs = sine_wavs.to(dtype=self.l_linear.weight.dtype)
-        sine_merge = self.l_tanh(self.l_linear(sine_wavs))
-        return sine_merge, None, None
-
-
-class GeneratorNSF(torch.nn.Module):
-    def __init__(self, initial_channel, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels, sr, is_half=False):
-        super(GeneratorNSF, self).__init__()
-
-        self.num_kernels = len(resblock_kernel_sizes)
-        self.num_upsamples = len(upsample_rates)
-        self.f0_upsamp = torch.nn.Upsample(scale_factor=math.prod(upsample_rates))
-        self.m_source = SourceModuleHnNSF(sample_rate=sr, harmonic_num=0, is_half=is_half)
-
-        self.conv_pre = torch.nn.Conv1d(initial_channel, upsample_initial_channel, 7, 1, padding=3)
-        resblock_cls = ResBlock1 if resblock == "1" else ResBlock2
-
-        self.ups = torch.nn.ModuleList()
-        self.noise_convs = torch.nn.ModuleList()
-
-        channels = [upsample_initial_channel // (2 ** (i + 1)) for i in range(len(upsample_rates))]
-        stride_f0s = [math.prod(upsample_rates[i + 1 :]) if i + 1 < len(upsample_rates) else 1 for i in range(len(upsample_rates))]
-
-        for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
-            self.ups.append(weight_norm(torch.nn.ConvTranspose1d(upsample_initial_channel // (2**i), channels[i], k, u, padding=(k - u) // 2)))
-            self.noise_convs.append(torch.nn.Conv1d(1, channels[i], kernel_size=(stride_f0s[i] * 2 if stride_f0s[i] > 1 else 1), stride=stride_f0s[i], padding=(stride_f0s[i] // 2 if stride_f0s[i] > 1 else 0)))
-
-        self.resblocks = torch.nn.ModuleList([resblock_cls(channels[i], k, d) for i in range(len(self.ups)) for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes)])
-
-        self.conv_post = torch.nn.Conv1d(channels[-1], 1, 7, 1, padding=3, bias=False)
-        self.ups.apply(init_weights)
-
-        if gin_channels != 0: self.cond = torch.nn.Conv1d(gin_channels, upsample_initial_channel, 1)
-
-        self.upp = math.prod(upsample_rates)
-        self.lrelu_slope = LRELU_SLOPE
-
-    def forward(self, x, f0, g: Optional[torch.Tensor] = None):
-        har_source, _, _ = self.m_source(f0, self.upp)
-        har_source = har_source.transpose(1, 2)
-        x = self.conv_pre(x)
-
-        if g is not None: x = x + self.cond(g)
-
-        for i, (ups, noise_convs) in enumerate(zip(self.ups, self.noise_convs)):
-            x = torch.nn.functional.leaky_relu(x, self.lrelu_slope)
-            x = ups(x)
-            x = x + noise_convs(har_source)
-
-            xs = sum([resblock(x) for j, resblock in enumerate(self.resblocks) if j in range(i * self.num_kernels, (i + 1) * self.num_kernels)])
-            x = xs / self.num_kernels
-
-        x = torch.nn.functional.leaky_relu(x)
-        x = torch.tanh(self.conv_post(x))
-        return x
-
-    def remove_weight_norm(self):
-        for l in self.ups:
-            remove_weight_norm(l)
-        for l in self.resblocks:
-            l.remove_weight_norm()
-
-    def __prepare_scriptable__(self):
-        for l in self.ups:
-            for hook in l._forward_pre_hooks.values():
-                if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): remove_weight_norm(l)
-
-        for l in self.resblocks:
-            for hook in l._forward_pre_hooks.values():
-                if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): remove_weight_norm(l)
-        return self
-
-class LayerNorm(torch.nn.Module):
-    def __init__(self, channels, eps=1e-5):
-        super().__init__()
-        self.eps = eps
-        self.gamma = torch.nn.Parameter(torch.ones(channels))
-        self.beta = torch.nn.Parameter(torch.zeros(channels))
-
-    def forward(self, x):
-        x = x.transpose(1, -1)
-        x = torch.nn.functional.layer_norm(x, (x.size(-1),), self.gamma, self.beta, self.eps)
-        return x.transpose(1, -1)
-
-class MultiHeadAttention(torch.nn.Module):
-    def __init__(self, channels, out_channels, n_heads, p_dropout=0.0, window_size=None, heads_share=True, block_length=None, proximal_bias=False, proximal_init=False):
-        super().__init__()
-        assert channels % n_heads == 0
-
-        self.channels = channels
-        self.out_channels = out_channels
-        self.n_heads = n_heads
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-        self.heads_share = heads_share
-        self.block_length = block_length
-        self.proximal_bias = proximal_bias
-        self.proximal_init = proximal_init
-        self.attn = None
-        self.k_channels = channels // n_heads
-        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
-        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
-        self.conv_v = torch.nn.Conv1d(channels, channels, 1)
-        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
-        self.drop = torch.nn.Dropout(p_dropout)
-
-        if window_size is not None:
-            n_heads_rel = 1 if heads_share else n_heads
-            rel_stddev = self.k_channels**-0.5
-
-            self.emb_rel_k = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
-            self.emb_rel_v = torch.nn.Parameter(torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels) * rel_stddev)
-
-        torch.nn.init.xavier_uniform_(self.conv_q.weight)
-        torch.nn.init.xavier_uniform_(self.conv_k.weight)
-        torch.nn.init.xavier_uniform_(self.conv_v.weight)
-
-        if proximal_init:
-            with torch.no_grad():
-                self.conv_k.weight.copy_(self.conv_q.weight)
-                self.conv_k.bias.copy_(self.conv_q.bias)
-
-    def forward(self, x, c, attn_mask=None):
-        q = self.conv_q(x)
-        k = self.conv_k(c)
-        v = self.conv_v(c)
-
-        x, self.attn = self.attention(q, k, v, mask=attn_mask)
-
-        x = self.conv_o(x)
-        return x
-
-    def attention(self, query, key, value, mask=None):
-        b, d, t_s, t_t = (*key.size(), query.size(2))
-        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
-        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
-
-        scores = torch.matmul(query / math.sqrt(self.k_channels), key.transpose(-2, -1))
-        if self.window_size is not None:
-            assert (t_s == t_t), "(t_s == t_t)"
-            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
-            rel_logits = self._matmul_with_relative_keys(query / math.sqrt(self.k_channels), key_relative_embeddings)
-            scores_local = self._relative_position_to_absolute_position(rel_logits)
-            scores = scores + scores_local
-
-        if self.proximal_bias:
-            assert t_s == t_t, "t_s == t_t"
-            scores = scores + self._attention_bias_proximal(t_s).to(device=scores.device, dtype=scores.dtype)
-
-        if mask is not None:
-            scores = scores.masked_fill(mask == 0, -1e4)
-            if self.block_length is not None:
-                assert (t_s == t_t), "(t_s == t_t)"
-
-                block_mask = (torch.ones_like(scores).triu(-self.block_length).tril(self.block_length))
-                scores = scores.masked_fill(block_mask == 0, -1e4)
-
-        p_attn = torch.nn.functional.softmax(scores, dim=-1)
-        p_attn = self.drop(p_attn)
-        output = torch.matmul(p_attn, value)
-
-        if self.window_size is not None:
-            relative_weights = self._absolute_position_to_relative_position(p_attn)
-            value_relative_embeddings = self._get_relative_embeddings(self.emb_rel_v, t_s)
-            output = output + self._matmul_with_relative_values(relative_weights, value_relative_embeddings)
-
-        output = (output.transpose(2, 3).contiguous().view(b, d, t_t))
-        return output, p_attn
-
-    def _matmul_with_relative_values(self, x, y):
-        ret = torch.matmul(x, y.unsqueeze(0))
-        return ret
-
-    def _matmul_with_relative_keys(self, x, y):
-        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
-        return ret
-
-    def _get_relative_embeddings(self, relative_embeddings, length):
-        pad_length = max(length - (self.window_size + 1), 0)
-        slice_start_position = max((self.window_size + 1) - length, 0)
-        slice_end_position = slice_start_position + 2 * length - 1
-
-        if pad_length > 0: padded_relative_embeddings = torch.nn.functional.pad(relative_embeddings, convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]))
-        else: padded_relative_embeddings = relative_embeddings
-
-        used_relative_embeddings = padded_relative_embeddings[:, slice_start_position:slice_end_position]
-        return used_relative_embeddings
-
-    def _relative_position_to_absolute_position(self, x):
-        batch, heads, length, _ = x.size()
-
-        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))
-
-        x_flat = x.view([batch, heads, length * 2 * length])
-        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [0, length - 1]]))
-
-        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[:, :, :length, length - 1 :]
-
-        return x_final
-
-    def _absolute_position_to_relative_position(self, x):
-        batch, heads, length, _ = x.size()
-        x = torch.nn.functional.pad(x, convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]]))
-        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
-        x_flat = torch.nn.functional.pad(x_flat, convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
-        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
-        return x_final
-
-    def _attention_bias_proximal(self, length):
-        r = torch.arange(length, dtype=torch.float32)
-        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
-        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
-
-
-class FFN(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0.0, activation=None, causal=False):
-        super().__init__()
-
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.filter_channels = filter_channels
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.activation = activation
-        self.causal = causal
-
-        if causal: self.padding = self._causal_padding
-        else: self.padding = self._same_padding
-
-        self.conv_1 = torch.nn.Conv1d(in_channels, filter_channels, kernel_size)
-        self.conv_2 = torch.nn.Conv1d(filter_channels, out_channels, kernel_size)
-        self.drop = torch.nn.Dropout(p_dropout)
-
-    def forward(self, x, x_mask):
-        x = self.conv_1(self.padding(x * x_mask))
-
-        if self.activation == "gelu": x = x * torch.sigmoid(1.702 * x)
-        else: x = torch.relu(x)
-
-        x = self.drop(x)
-        x = self.conv_2(self.padding(x * x_mask))
-        return x * x_mask
-
-    def _causal_padding(self, x):
-        if self.kernel_size == 1: return x
-
-        pad_l = self.kernel_size - 1
-        pad_r = 0
-        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = torch.nn.functional.pad(x, convert_pad_shape(padding))
-        return x
-
-    def _same_padding(self, x):
-        if self.kernel_size == 1: return x
-
-        pad_l = (self.kernel_size - 1) // 2
-        pad_r = self.kernel_size // 2
-
-        padding = [[0, 0], [0, 0], [pad_l, pad_r]]
-        x = torch.nn.functional.pad(x, convert_pad_shape(padding))
-
-        return x
-
-class Encoder(torch.nn.Module):
-    def __init__(self, hidden_channels, filter_channels, n_heads, n_layers, kernel_size=1, p_dropout=0.0, window_size=10, **kwargs):
-        super().__init__()
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = p_dropout
-        self.window_size = window_size
-
-        self.drop = torch.nn.Dropout(p_dropout)
-        self.attn_layers = torch.nn.ModuleList()
-        self.norm_layers_1 = torch.nn.ModuleList()
-        self.ffn_layers = torch.nn.ModuleList()
-        self.norm_layers_2 = torch.nn.ModuleList()
-
-        for _ in range(self.n_layers):
-            self.attn_layers.append(MultiHeadAttention(hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout, window_size=window_size))
-            self.norm_layers_1.append(LayerNorm(hidden_channels))
-            self.ffn_layers.append(FFN(hidden_channels, hidden_channels, filter_channels, kernel_size, p_dropout=p_dropout))
-            self.norm_layers_2.append(LayerNorm(hidden_channels))
-
-    def forward(self, x, x_mask):
-        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
-        x = x * x_mask
-
-        for i in range(self.n_layers):
-            y = self.attn_layers[i](x, x, attn_mask)
-            y = self.drop(y)
-            x = self.norm_layers_1[i](x + y)
-
-            y = self.ffn_layers[i](x, x_mask)
-            y = self.drop(y)
-            x = self.norm_layers_2[i](x + y)
-
-        x = x * x_mask
-        return x
-
-
-class TextEncoder(torch.nn.Module):
-    def __init__(self, out_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, embedding_dim, f0=True):
-        super(TextEncoder, self).__init__()
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = float(p_dropout)
-        self.emb_phone = torch.nn.Linear(embedding_dim, hidden_channels)
-        self.lrelu = torch.nn.LeakyReLU(0.1, inplace=True)
-
-        if f0: self.emb_pitch = torch.nn.Embedding(256, hidden_channels)
-
-        self.encoder = Encoder(hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout))
-        self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(self, phone: torch.Tensor, pitch: Optional[torch.Tensor], lengths: torch.Tensor):
-        if pitch is None: x = self.emb_phone(phone)
-        else: x = self.emb_phone(phone) + self.emb_pitch(pitch)
-
-        x = x * math.sqrt(self.hidden_channels)
-        x = self.lrelu(x)
-        x = torch.transpose(x, 1, -1)
-        x_mask = torch.unsqueeze(sequence_mask(lengths, x.size(2)), 1).to(x.dtype)
-        x = self.encoder(x * x_mask, x_mask)
-        stats = self.proj(x) * x_mask
-
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        return m, logs, x_mask
-
-
-class PosteriorEncoder(torch.nn.Module):
-    def __init__(self, in_channels, out_channels, hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=0):
-        super(PosteriorEncoder, self).__init__()
-        self.in_channels = in_channels
-        self.out_channels = out_channels
-        self.hidden_channels = hidden_channels
-        self.kernel_size = kernel_size
-        self.dilation_rate = dilation_rate
-        self.n_layers = n_layers
-        self.gin_channels = gin_channels
-
-        self.pre = torch.nn.Conv1d(in_channels, hidden_channels, 1)
-        self.enc = WaveNet(hidden_channels, kernel_size, dilation_rate, n_layers, gin_channels=gin_channels)
-        self.proj = torch.nn.Conv1d(hidden_channels, out_channels * 2, 1)
-
-    def forward(self, x: torch.Tensor, x_lengths: torch.Tensor, g: Optional[torch.Tensor] = None):
-        x_mask = torch.unsqueeze(sequence_mask(x_lengths, x.size(2)), 1).to(x.dtype)
-        x = self.pre(x) * x_mask
-        x = self.enc(x, x_mask, g=g)
-        stats = self.proj(x) * x_mask
-        m, logs = torch.split(stats, self.out_channels, dim=1)
-        z = (m + torch.randn_like(m) * torch.exp(logs)) * x_mask
-        return z, m, logs, x_mask
-
-    def remove_weight_norm(self):
-        self.enc.remove_weight_norm()
-
-    def __prepare_scriptable__(self):
-        for hook in self.enc._forward_pre_hooks.values():
-            if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.enc)
-
-        return self
-
-class Synthesizer(torch.nn.Module):
-    def __init__(self, spec_channels, segment_size, inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, p_dropout, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, spk_embed_dim, gin_channels, sr, use_f0, text_enc_hidden_dim=768, **kwargs):
-        super(Synthesizer, self).__init__()
-        self.spec_channels = spec_channels
-        self.inter_channels = inter_channels
-        self.hidden_channels = hidden_channels
-        self.filter_channels = filter_channels
-        self.n_heads = n_heads
-        self.n_layers = n_layers
-        self.kernel_size = kernel_size
-        self.p_dropout = float(p_dropout)
-        self.resblock = resblock
-        self.resblock_kernel_sizes = resblock_kernel_sizes
-        self.resblock_dilation_sizes = resblock_dilation_sizes
-        self.upsample_rates = upsample_rates
-        self.upsample_initial_channel = upsample_initial_channel
-        self.upsample_kernel_sizes = upsample_kernel_sizes
-        self.segment_size = segment_size
-        self.gin_channels = gin_channels
-        self.spk_embed_dim = spk_embed_dim
-        self.use_f0 = use_f0
-
-        self.enc_p = TextEncoder(inter_channels, hidden_channels, filter_channels, n_heads, n_layers, kernel_size, float(p_dropout), text_enc_hidden_dim, f0=use_f0)
-
-        if use_f0: self.dec = GeneratorNSF(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels, sr=sr, is_half=kwargs["is_half"])
-        else: self.dec = Generator(inter_channels, resblock, resblock_kernel_sizes, resblock_dilation_sizes, upsample_rates, upsample_initial_channel, upsample_kernel_sizes, gin_channels=gin_channels)
-
-        self.enc_q = PosteriorEncoder(spec_channels, inter_channels, hidden_channels, 5, 1, 16, gin_channels=gin_channels)
-        self.flow = ResidualCouplingBlock(inter_channels, hidden_channels, 5, 1, 3, gin_channels=gin_channels)
-        self.emb_g = torch.nn.Embedding(self.spk_embed_dim, gin_channels)
-
-    def remove_weight_norm(self):
-        self.dec.remove_weight_norm()
-        self.flow.remove_weight_norm()
-        self.enc_q.remove_weight_norm()
-
-    def __prepare_scriptable__(self):
-        for hook in self.dec._forward_pre_hooks.values():
-            if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.dec)
-
-        for hook in self.flow._forward_pre_hooks.values():
-            if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.flow)
-
-        if hasattr(self, "enc_q"):
-            for hook in self.enc_q._forward_pre_hooks.values():
-                if (hook.__module__ == "torch.nn.utils.parametrizations.weight_norm" and hook.__class__.__name__ == "WeightNorm"): torch.nn.utils.remove_weight_norm(self.enc_q)
-
-        return self
-
-    @torch.jit.ignore
-    def forward(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: Optional[torch.Tensor] = None, pitchf: Optional[torch.Tensor] = None, y: torch.Tensor = None, y_lengths: torch.Tensor = None, ds: Optional[torch.Tensor] = None):
-        g = self.emb_g(ds).unsqueeze(-1)
-        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
-
-        if y is not None:
-            z, m_q, logs_q, y_mask = self.enc_q(y, y_lengths, g=g)
-            z_p = self.flow(z, y_mask, g=g)
-            z_slice, ids_slice = rand_slice_segments(z, y_lengths, self.segment_size)
-
-            if self.use_f0:
-                pitchf = slice_segments(pitchf, ids_slice, self.segment_size, 2)
-                o = self.dec(z_slice, pitchf, g=g)
-            else: o = self.dec(z_slice, g=g)
-
-            return o, ids_slice, x_mask, y_mask, (z, z_p, m_p, logs_p, m_q, logs_q)
-        else: return None, None, x_mask, None, (None, None, m_p, logs_p, None, None)
-
-    @torch.jit.export
-    def infer(self, phone: torch.Tensor, phone_lengths: torch.Tensor, pitch: Optional[torch.Tensor] = None, nsff0: Optional[torch.Tensor] = None, sid: torch.Tensor = None, rate: Optional[torch.Tensor] = None):
-        g = self.emb_g(sid).unsqueeze(-1)
-        m_p, logs_p, x_mask = self.enc_p(phone, pitch, phone_lengths)
-        z_p = (m_p + torch.exp(logs_p) * torch.randn_like(m_p) * 0.66666) * x_mask
-
-        if rate is not None:
-            assert isinstance(rate, torch.Tensor)
-            head = int(z_p.shape[2] * (1.0 - rate.item()))
-            z_p = z_p[:, :, head:]
-            x_mask = x_mask[:, :, head:]
-
-            if self.use_f0: nsff0 = nsff0[:, head:]
-
-        if self.use_f0:
-            z = self.flow(z_p, x_mask, g=g, reverse=True)
-            o = self.dec(z * x_mask, nsff0, g=g)
-        else:
-            z = self.flow(z_p, x_mask, g=g, reverse=True)
-            o = self.dec(z * x_mask, g=g)
-
-        return o, x_mask, (z, z_p, m_p, logs_p)
main/library/architectures/demucs_separator.py
DELETED
@@ -1,340 +0,0 @@
-import os
-import sys
-import yaml
-import torch
-
-import numpy as np
-import typing as tp
-
-from pathlib import Path
-from hashlib import sha256
-
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-
-from main.configs.config import Config
-from main.library.uvr5_separator import spec_utils
-from main.library.uvr5_separator.demucs.hdemucs import HDemucs
-from main.library.uvr5_separator.demucs.states import load_model
-from main.library.uvr5_separator.demucs.apply import BagOfModels, Model
-from main.library.uvr5_separator.common_separator import CommonSeparator
-from main.library.uvr5_separator.demucs.apply import apply_model, demucs_segments
-
-
-translations = Config().translations
-
-DEMUCS_4_SOURCE = ["drums", "bass", "other", "vocals"]
-
-DEMUCS_2_SOURCE_MAPPER = {
-    CommonSeparator.INST_STEM: 0,
-    CommonSeparator.VOCAL_STEM: 1
-}
-
-DEMUCS_4_SOURCE_MAPPER = {
-    CommonSeparator.BASS_STEM: 0,
-    CommonSeparator.DRUM_STEM: 1,
-    CommonSeparator.OTHER_STEM: 2,
-    CommonSeparator.VOCAL_STEM: 3
-}
-
-DEMUCS_6_SOURCE_MAPPER = {
-    CommonSeparator.BASS_STEM: 0,
-    CommonSeparator.DRUM_STEM: 1,
-    CommonSeparator.OTHER_STEM: 2,
-    CommonSeparator.VOCAL_STEM: 3,
-    CommonSeparator.GUITAR_STEM: 4,
-    CommonSeparator.PIANO_STEM: 5,
-}
-
-
-REMOTE_ROOT = Path(__file__).parent / "remote"
-
-PRETRAINED_MODELS = {
-    "demucs": "e07c671f",
-    "demucs48_hq": "28a1282c",
-    "demucs_extra": "3646af93",
-    "demucs_quantized": "07afea75",
-    "tasnet": "beb46fac",
-    "tasnet_extra": "df3777b2",
-    "demucs_unittest": "09ebc15f",
-}
-
-
-sys.path.insert(0, os.path.join(os.getcwd(), "main", "library", "uvr5_separator"))
-
-AnyModel = tp.Union[Model, BagOfModels]
-
-
-class DemucsSeparator(CommonSeparator):
-    def __init__(self, common_config, arch_config):
-        super().__init__(config=common_config)
-
-        self.segment_size = arch_config.get("segment_size", "Default")
-        self.shifts = arch_config.get("shifts", 2)
-        self.overlap = arch_config.get("overlap", 0.25)
-        self.segments_enabled = arch_config.get("segments_enabled", True)
-
-        self.logger.debug(translations["demucs_info"].format(segment_size=self.segment_size, segments_enabled=self.segments_enabled))
-        self.logger.debug(translations["demucs_info_2"].format(shifts=self.shifts, overlap=self.overlap))
-
-        self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
-
-        self.audio_file_path = None
-        self.audio_file_base = None
-        self.demucs_model_instance = None
-
-        self.logger.info(translations["start_demucs"])
-
-    def separate(self, audio_file_path):
-        self.logger.debug(translations["start_separator"])
-
-        source = None
-        stem_source = None
-
-        inst_source = {}
-
-        self.audio_file_path = audio_file_path
-        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]
-
-        self.logger.debug(translations["prepare_mix"])
-        mix = self.prepare_mix(self.audio_file_path)
-
-        self.logger.debug(translations["demix"].format(shape=mix.shape))
-
-        self.logger.debug(translations["cancel_mix"])
-
-        self.demucs_model_instance = HDemucs(sources=DEMUCS_4_SOURCE)
-        self.demucs_model_instance = get_demucs_model(name=os.path.splitext(os.path.basename(self.model_path))[0], repo=Path(os.path.dirname(self.model_path)))
-        self.demucs_model_instance = demucs_segments(self.segment_size, self.demucs_model_instance)
-        self.demucs_model_instance.to(self.torch_device)
-        self.demucs_model_instance.eval()
-
-        self.logger.debug(translations["model_review"])
-
-        source = self.demix_demucs(mix)
-
-        del self.demucs_model_instance
-        self.clear_gpu_cache()
-        self.logger.debug(translations["del_gpu_cache_after_demix"])
-
-        output_files = []
-        self.logger.debug(translations["process_output_file"])
-
-        if isinstance(inst_source, np.ndarray):
-            self.logger.debug(translations["process_ver"])
-            source_reshape = spec_utils.reshape_sources(inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]], source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]])
-            inst_source[self.demucs_source_map[CommonSeparator.VOCAL_STEM]] = source_reshape
-            source = inst_source
-
-        if isinstance(source, np.ndarray):
-            source_length = len(source)
-            self.logger.debug(translations["source_length"].format(source_length=source_length))
-
-            match source_length:
-                case 2:
-                    self.logger.debug(translations["set_map"].format(part="2"))
-                    self.demucs_source_map = DEMUCS_2_SOURCE_MAPPER
-                case 6:
-                    self.logger.debug(translations["set_map"].format(part="6"))
-                    self.demucs_source_map = DEMUCS_6_SOURCE_MAPPER
-                case _:
-                    self.logger.debug(translations["set_map"].format(part="2"))
-                    self.demucs_source_map = DEMUCS_4_SOURCE_MAPPER
-
-        self.logger.debug(translations["process_all_part"])
-
-        for stem_name, stem_value in self.demucs_source_map.items():
-            if self.output_single_stem is not None:
-                if stem_name.lower() != self.output_single_stem.lower():
-                    self.logger.debug(translations["skip_part"].format(stem_name=stem_name, output_single_stem=self.output_single_stem))
-                    continue
-
-            stem_path = os.path.join(f"{self.audio_file_base}_({stem_name})_{self.model_name}.{self.output_format.lower()}")
-            stem_source = source[stem_value].T
-
-            self.final_process(stem_path, stem_source, stem_name)
-            output_files.append(stem_path)
-
-        return output_files
-
-    def demix_demucs(self, mix):
-        self.logger.debug(translations["starting_demix_demucs"])
-
-        processed = {}
-        mix = torch.tensor(mix, dtype=torch.float32)
-        ref = mix.mean(0)
-        mix = (mix - ref.mean()) / ref.std()
-        mix_infer = mix
-
-        with torch.no_grad():
-            self.logger.debug(translations["model_infer"])
-            sources = apply_model(model=self.demucs_model_instance, mix=mix_infer[None], shifts=self.shifts, split=self.segments_enabled, overlap=self.overlap, static_shifts=1 if self.shifts == 0 else self.shifts, set_progress_bar=None, device=self.torch_device, progress=True)[0]
-
-        sources = (sources * ref.std() + ref.mean()).cpu().numpy()
-        sources[[0, 1]] = sources[[1, 0]]
-
-        processed[mix] = sources[:, :, 0:None].copy()
-
-        sources = list(processed.values())
-        sources = [s[:, :, 0:None] for s in sources]
-        sources = np.concatenate(sources, axis=-1)
-        return sources
-
-
-class ModelOnlyRepo:
-    def has_model(self, sig: str) -> bool:
-        raise NotImplementedError()
-
-    def get_model(self, sig: str) -> Model:
-        raise NotImplementedError()
-
-
-class RemoteRepo(ModelOnlyRepo):
-    def __init__(self, models: tp.Dict[str, str]):
-        self._models = models
-
-    def has_model(self, sig: str) -> bool:
-        return sig in self._models
-
-    def get_model(self, sig: str) -> Model:
-        try:
-            url = self._models[sig]
-        except KeyError:
-            raise RuntimeError(translations["not_found_model_signature"].format(sig=sig))
-
-        pkg = torch.hub.load_state_dict_from_url(url, map_location="cpu", check_hash=True)
-        return load_model(pkg)
-
-
-class LocalRepo(ModelOnlyRepo):
-    def __init__(self, root: Path):
-        self.root = root
-        self.scan()
-
-    def scan(self):
-        self._models = {}
-        self._checksums = {}
-
-        for file in self.root.iterdir():
-            if file.suffix == ".th":
-                if "-" in file.stem:
-                    xp_sig, checksum = file.stem.split("-")
-                    self._checksums[xp_sig] = checksum
-                else: xp_sig = file.stem
-
-                if xp_sig in self._models: raise RuntimeError(translations["del_all_but_one"].format(xp_sig=xp_sig))
-
-                self._models[xp_sig] = file
-
-    def has_model(self, sig: str) -> bool:
-        return sig in self._models
-
-    def get_model(self, sig: str) -> Model:
-        try:
-            file = self._models[sig]
-        except KeyError:
-            raise RuntimeError(translations["not_found_model_signature"].format(sig=sig))
-
-        if sig in self._checksums: check_checksum(file, self._checksums[sig])
-
-        return load_model(file)
-
-
-class BagOnlyRepo:
-    def __init__(self, root: Path, model_repo: ModelOnlyRepo):
-        self.root = root
-        self.model_repo = model_repo
-        self.scan()
-
-    def scan(self):
-        self._bags = {}
-
-        for file in self.root.iterdir():
-            if file.suffix == ".yaml": self._bags[file.stem] = file
-
-    def has_model(self, name: str) -> bool:
-        return name in self._bags
-
-    def get_model(self, name: str) -> BagOfModels:
-        try:
-            yaml_file = self._bags[name]
-        except KeyError:
-            raise RuntimeError(translations["name_not_pretrained"].format(name=name))
-
-        bag = yaml.safe_load(open(yaml_file))
-        signatures = bag["models"]
-        models = [self.model_repo.get_model(sig) for sig in signatures]
-
-        weights = bag.get("weights")
-        segment = bag.get("segment")
-
-        return BagOfModels(models, weights, segment)
-
-
-class AnyModelRepo:
-    def __init__(self, model_repo: ModelOnlyRepo, bag_repo: BagOnlyRepo):
-        self.model_repo = model_repo
-        self.bag_repo = bag_repo
-
-    def has_model(self, name_or_sig: str) -> bool:
-        return self.model_repo.has_model(name_or_sig) or self.bag_repo.has_model(name_or_sig)
-
-    def get_model(self, name_or_sig: str) -> AnyModel:
-        if self.model_repo.has_model(name_or_sig): return self.model_repo.get_model(name_or_sig)
-        else: return self.bag_repo.get_model(name_or_sig)
-
-
-def check_checksum(path: Path, checksum: str):
-    sha = sha256()
-
-    with open(path, "rb") as file:
-        while 1:
-            buf = file.read(2**20)
-            if not buf: break
-
-            sha.update(buf)
-
-    actual_checksum = sha.hexdigest()[: len(checksum)]
-
-    if actual_checksum != checksum: raise RuntimeError(translations["invalid_checksum"].format(path=path, checksum=checksum, actual_checksum=actual_checksum))
-
-
-def _parse_remote_files(remote_file_list) -> tp.Dict[str, str]:
-    root: str = ""
-    models: tp.Dict[str, str] = {}
-
-    for line in remote_file_list.read_text().split("\n"):
-        line = line.strip()
-
-        if line.startswith("#"): continue
-        elif line.startswith("root:"): root = line.split(":", 1)[1].strip()
-        else:
-            sig = line.split("-", 1)[0]
-            assert sig not in models
-
-            models[sig] = "https://dl.fbaipublicfiles.com/demucs/mdx_final/" + root + line
-
-    return models
-
-
-def get_demucs_model(name: str, repo: tp.Optional[Path] = None):
-    if name == "demucs_unittest": return HDemucs(channels=4, sources=DEMUCS_4_SOURCE)
-
-    model_repo: ModelOnlyRepo
-
-    if repo is None:
-        models = _parse_remote_files(REMOTE_ROOT / "files.txt")
-        model_repo = RemoteRepo(models)
-        bag_repo = BagOnlyRepo(REMOTE_ROOT, model_repo)
-    else:
-        if not repo.is_dir(): print(translations["repo_must_be_folder"].format(repo=repo))
|
331 |
-
|
332 |
-
model_repo = LocalRepo(repo)
|
333 |
-
bag_repo = BagOnlyRepo(repo, model_repo)
|
334 |
-
|
335 |
-
any_repo = AnyModelRepo(model_repo, bag_repo)
|
336 |
-
|
337 |
-
model = any_repo.get_model(name)
|
338 |
-
model.eval()
|
339 |
-
|
340 |
-
return model
|
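The repo classes deleted above resolve a pretrained Demucs model in two stages: a bare signature is answered by RemoteRepo or LocalRepo, while a bag name falls through to BagOnlyRepo, which loads every signature listed in its yaml. A minimal sketch of that resolution chain, assuming a hypothetical local folder holding .th checkpoints and .yaml bag files:

from pathlib import Path

# Hypothetical checkpoint folder; LocalRepo.scan() indexes files named
# "<signature>-<checksum>.th", BagOnlyRepo.scan() indexes "<bag name>.yaml".
repo_dir = Path("checkpoints")

model_repo = LocalRepo(repo_dir)               # single-model signatures
bag_repo = BagOnlyRepo(repo_dir, model_repo)   # yaml bags of signatures
any_repo = AnyModelRepo(model_repo, bag_repo)

# A signature hits LocalRepo directly; a bag name such as "htdemucs" makes
# BagOnlyRepo load each listed signature and wrap them in BagOfModels.
model = any_repo.get_model("htdemucs")
model.eval()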
main/library/architectures/mdx_separator.py
DELETED
@@ -1,370 +0,0 @@
import os
import sys
import onnx
import torch
import platform
import onnx2torch

import numpy as np
import onnxruntime as ort

from tqdm import tqdm

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
from main.library.uvr5_separator import spec_utils
from main.library.uvr5_separator.common_separator import CommonSeparator


translations = Config().translations

class MDXSeparator(CommonSeparator):
    def __init__(self, common_config, arch_config):
        super().__init__(config=common_config)

        self.segment_size = arch_config.get("segment_size")
        self.overlap = arch_config.get("overlap")
        self.batch_size = arch_config.get("batch_size", 1)
        self.hop_length = arch_config.get("hop_length")
        self.enable_denoise = arch_config.get("enable_denoise")
        self.logger.debug(translations["mdx_info"].format(batch_size=self.batch_size, segment_size=self.segment_size))
        self.logger.debug(translations["mdx_info_2"].format(overlap=self.overlap, hop_length=self.hop_length, enable_denoise=self.enable_denoise))
        self.compensate = self.model_data["compensate"]
        self.dim_f = self.model_data["mdx_dim_f_set"]
        self.dim_t = 2 ** self.model_data["mdx_dim_t_set"]
        self.n_fft = self.model_data["mdx_n_fft_scale_set"]
        self.config_yaml = self.model_data.get("config_yaml", None)
        self.logger.debug(f"{translations['mdx_info_3']}: compensate = {self.compensate}, dim_f = {self.dim_f}, dim_t = {self.dim_t}, n_fft = {self.n_fft}")
        self.logger.debug(f"{translations['mdx_info_3']}: config_yaml = {self.config_yaml}")
        self.load_model()
        self.n_bins = 0
        self.trim = 0
        self.chunk_size = 0
        self.gen_size = 0
        self.stft = None
        self.primary_source = None
        self.secondary_source = None
        self.audio_file_path = None
        self.audio_file_base = None


    def load_model(self):
        self.logger.debug(translations["load_model_onnx"])

        if self.segment_size == self.dim_t:
            ort_session_options = ort.SessionOptions()

            if self.log_level > 10: ort_session_options.log_severity_level = 3
            else: ort_session_options.log_severity_level = 0

            ort_inference_session = ort.InferenceSession(self.model_path, providers=self.onnx_execution_provider, sess_options=ort_session_options)
            self.model_run = lambda spek: ort_inference_session.run(None, {"input": spek.cpu().numpy()})[0]
            self.logger.debug(translations["load_model_onnx_success"])
        else:
            if platform.system() == 'Windows':
                onnx_model = onnx.load(self.model_path)
                self.model_run = onnx2torch.convert(onnx_model)
            else: self.model_run = onnx2torch.convert(self.model_path)

            self.model_run.to(self.torch_device).eval()
            self.logger.warning(translations["onnx_to_pytorch"])

    def separate(self, audio_file_path):
        self.audio_file_path = audio_file_path
        self.audio_file_base = os.path.splitext(os.path.basename(audio_file_path))[0]

        self.logger.debug(translations["mix"].format(audio_file_path=self.audio_file_path))
        mix = self.prepare_mix(self.audio_file_path)

        self.logger.debug(translations["normalization_demix"])
        mix = spec_utils.normalize(wave=mix, max_peak=self.normalization_threshold)

        source = self.demix(mix)
        self.logger.debug(translations["mix_success"])

        output_files = []
        self.logger.debug(translations["process_output_file"])

        if not isinstance(self.primary_source, np.ndarray):
            self.logger.debug(translations["primary_source"])
            self.primary_source = spec_utils.normalize(wave=source, max_peak=self.normalization_threshold).T
        if not isinstance(self.secondary_source, np.ndarray):
            self.logger.debug(translations["secondary_source"])
            raw_mix = self.demix(mix, is_match_mix=True)

            if self.invert_using_spec:
                self.logger.debug(translations["invert_using_spec"])
                self.secondary_source = spec_utils.invert_stem(raw_mix, source)
            else:
                self.logger.debug(translations["invert_using_spec_2"])
                self.secondary_source = mix.T - source.T

        if not self.output_single_stem or self.output_single_stem.lower() == self.secondary_stem_name.lower():
            self.secondary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.secondary_stem_name})_{self.model_name}.{self.output_format.lower()}")
            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.secondary_stem_name, stem_output_path=self.secondary_stem_output_path))
            self.final_process(self.secondary_stem_output_path, self.secondary_source, self.secondary_stem_name)
            output_files.append(self.secondary_stem_output_path)

        if not self.output_single_stem or self.output_single_stem.lower() == self.primary_stem_name.lower():
            self.primary_stem_output_path = os.path.join(f"{self.audio_file_base}_({self.primary_stem_name})_{self.model_name}.{self.output_format.lower()}")

            if not isinstance(self.primary_source, np.ndarray): self.primary_source = source.T

            self.logger.info(translations["save_secondary_stem_output_path"].format(stem_name=self.primary_stem_name, stem_output_path=self.primary_stem_output_path))
            self.final_process(self.primary_stem_output_path, self.primary_source, self.primary_stem_name)
            output_files.append(self.primary_stem_output_path)

        return output_files

    def initialize_model_settings(self):
        self.logger.debug(translations["starting_model"])

        self.n_bins = self.n_fft // 2 + 1
        self.trim = self.n_fft // 2

        self.chunk_size = self.hop_length * (self.segment_size - 1)
        self.gen_size = self.chunk_size - 2 * self.trim

        self.stft = STFT(self.logger, self.n_fft, self.hop_length, self.dim_f, self.torch_device)

        self.logger.debug(f"{translations['input_info']}: n_fft = {self.n_fft} hop_length = {self.hop_length} dim_f = {self.dim_f}")
        self.logger.debug(f"{translations['model_settings']}: n_bins = {self.n_bins}, Trim = {self.trim}, chunk_size = {self.chunk_size}, gen_size = {self.gen_size}")

    def initialize_mix(self, mix, is_ckpt=False):
        self.logger.debug(translations["initialize_mix"].format(is_ckpt=is_ckpt, shape=mix.shape))

        if mix.shape[0] != 2:
            error_message = translations["!=2"].format(shape=mix.shape[0])
            self.logger.error(error_message)
            raise ValueError(error_message)

        if is_ckpt:
            self.logger.debug(translations["process_check"])
            pad = self.gen_size + self.trim - (mix.shape[-1] % self.gen_size)
            self.logger.debug(f"{translations['cache']}: {pad}")

            mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)

            num_chunks = mixture.shape[-1] // self.gen_size
            self.logger.debug(translations["shape"].format(shape=mixture.shape, num_chunks=num_chunks))

            mix_waves = [mixture[:, i * self.gen_size : i * self.gen_size + self.chunk_size] for i in range(num_chunks)]
        else:
            self.logger.debug(translations["process_no_check"])
            mix_waves = []
            n_sample = mix.shape[1]

            pad = self.gen_size - n_sample % self.gen_size
            self.logger.debug(translations["n_sample_or_pad"].format(n_sample=n_sample, pad=pad))

            mix_p = np.concatenate((np.zeros((2, self.trim)), mix, np.zeros((2, pad)), np.zeros((2, self.trim))), 1)
            self.logger.debug(f"{translations['shape_2']}: {mix_p.shape}")

            i = 0
            while i < n_sample + pad:
                waves = np.array(mix_p[:, i : i + self.chunk_size])
                mix_waves.append(waves)

                self.logger.debug(translations["process_part"].format(mix_waves=len(mix_waves), i=i, ii=i + self.chunk_size))
                i += self.gen_size

        mix_waves_tensor = torch.tensor(mix_waves, dtype=torch.float32).to(self.torch_device)
        self.logger.debug(translations["mix_waves_to_tensor"].format(shape=mix_waves_tensor.shape))

        return mix_waves_tensor, pad

    def demix(self, mix, is_match_mix=False):
        self.logger.debug(f"{translations['demix_is_match_mix']}: {is_match_mix}...")
        self.initialize_model_settings()

        org_mix = mix
        self.logger.debug(f"{translations['mix_shape']}: {org_mix.shape}")

        tar_waves_ = []

        if is_match_mix:
            chunk_size = self.hop_length * (self.segment_size - 1)
            overlap = 0.02
            self.logger.debug(translations["chunk_size_or_overlap"].format(chunk_size=chunk_size, overlap=overlap))
        else:
            chunk_size = self.chunk_size
            overlap = self.overlap
            self.logger.debug(translations["chunk_size_or_overlap_standard"].format(chunk_size=chunk_size, overlap=overlap))

        gen_size = chunk_size - 2 * self.trim
        self.logger.debug(f"{translations['calc_size']}: {gen_size}")

        pad = gen_size + self.trim - ((mix.shape[-1]) % gen_size)

        mixture = np.concatenate((np.zeros((2, self.trim), dtype="float32"), mix, np.zeros((2, pad), dtype="float32")), 1)
        self.logger.debug(f"{translations['mix_cache']}: {mixture.shape}")

        step = int((1 - overlap) * chunk_size)
        self.logger.debug(translations["step_or_overlap"].format(step=step, overlap=overlap))

        result = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)
        divider = np.zeros((1, 2, mixture.shape[-1]), dtype=np.float32)

        total = 0
        total_chunks = (mixture.shape[-1] + step - 1) // step
        self.logger.debug(f"{translations['all_process_part']}: {total_chunks}")

        for i in tqdm(range(0, mixture.shape[-1], step)):
            total += 1
            start = i
            end = min(i + chunk_size, mixture.shape[-1])
            self.logger.debug(translations["process_part_2"].format(total=total, total_chunks=total_chunks, start=start, end=end))

            chunk_size_actual = end - start
            window = None

            if overlap != 0:
                window = np.hanning(chunk_size_actual)
                window = np.tile(window[None, None, :], (1, 2, 1))
                self.logger.debug(translations["window"])

            mix_part_ = mixture[:, start:end]

            if end != i + chunk_size:
                pad_size = (i + chunk_size) - end
                mix_part_ = np.concatenate((mix_part_, np.zeros((2, pad_size), dtype="float32")), axis=-1)

            mix_part = torch.tensor([mix_part_], dtype=torch.float32).to(self.torch_device)

            mix_waves = mix_part.split(self.batch_size)
            total_batches = len(mix_waves)
            self.logger.debug(f"{translations['mix_or_batch']}: {total_batches}")

            with torch.no_grad():
                batches_processed = 0
                for mix_wave in mix_waves:
                    batches_processed += 1
                    self.logger.debug(f"{translations['mix_wave']} {batches_processed}/{total_batches}")

                    tar_waves = self.run_model(mix_wave, is_match_mix=is_match_mix)

                    if window is not None:
                        tar_waves[..., :chunk_size_actual] *= window
                        divider[..., start:end] += window
                    else: divider[..., start:end] += 1

                    result[..., start:end] += tar_waves[..., : end - start]

        self.logger.debug(translations["normalization_2"])
        tar_waves = result / divider
        tar_waves_.append(tar_waves)

        tar_waves_ = np.vstack(tar_waves_)[:, :, self.trim : -self.trim]
        tar_waves = np.concatenate(tar_waves_, axis=-1)[:, : mix.shape[-1]]

        source = tar_waves[:, 0:None]
        self.logger.debug(f"{translations['tar_waves']}: {tar_waves.shape}")

        if not is_match_mix:
            source *= self.compensate
            self.logger.debug(translations["mix_match"])

        self.logger.debug(translations["mix_success"])
        return source

    def run_model(self, mix, is_match_mix=False):
        spek = self.stft(mix.to(self.torch_device))
        self.logger.debug(translations["stft_2"].format(shape=spek.shape))

        spek[:, :, :3, :] *= 0

        if is_match_mix:
            spec_pred = spek.cpu().numpy()
            self.logger.debug(translations["is_match_mix"])
        else:
            if self.enable_denoise:
                spec_pred_neg = self.model_run(-spek)
                spec_pred_pos = self.model_run(spek)
                spec_pred = (spec_pred_neg * -0.5) + (spec_pred_pos * 0.5)
                self.logger.debug(translations["enable_denoise"])
            else:
                spec_pred = self.model_run(spek)
                self.logger.debug(translations["no_denoise"])

        result = self.stft.inverse(torch.tensor(spec_pred).to(self.torch_device)).cpu().detach().numpy()
        self.logger.debug(f"{translations['stft']}: {result.shape}")

        return result

class STFT:
    def __init__(self, logger, n_fft, hop_length, dim_f, device):
        self.logger = logger
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.dim_f = dim_f
        self.device = device
        self.hann_window = torch.hann_window(window_length=self.n_fft, periodic=True)

    def __call__(self, input_tensor):
        is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]

        if is_non_standard_device: input_tensor = input_tensor.cpu()

        stft_window = self.hann_window.to(input_tensor.device)

        batch_dimensions = input_tensor.shape[:-2]
        channel_dim, time_dim = input_tensor.shape[-2:]

        reshaped_tensor = input_tensor.reshape([-1, time_dim])
        stft_output = torch.stft(reshaped_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True, return_complex=False)

        permuted_stft_output = stft_output.permute([0, 3, 1, 2])

        final_output = permuted_stft_output.reshape([*batch_dimensions, channel_dim, 2, -1, permuted_stft_output.shape[-1]]).reshape([*batch_dimensions, channel_dim * 2, -1, permuted_stft_output.shape[-1]])

        if is_non_standard_device: final_output = final_output.to(self.device)

        return final_output[..., : self.dim_f, :]

    def pad_frequency_dimension(self, input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins):
        freq_padding = torch.zeros([*batch_dimensions, channel_dim, num_freq_bins - freq_dim, time_dim]).to(input_tensor.device)
        padded_tensor = torch.cat([input_tensor, freq_padding], -2)

        return padded_tensor

    def calculate_inverse_dimensions(self, input_tensor):
        batch_dimensions = input_tensor.shape[:-3]
        channel_dim, freq_dim, time_dim = input_tensor.shape[-3:]

        num_freq_bins = self.n_fft // 2 + 1

        return batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins

    def prepare_for_istft(self, padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim):
        reshaped_tensor = padded_tensor.reshape([*batch_dimensions, channel_dim // 2, 2, num_freq_bins, time_dim])
        flattened_tensor = reshaped_tensor.reshape([-1, 2, num_freq_bins, time_dim])
        permuted_tensor = flattened_tensor.permute([0, 2, 3, 1])
        complex_tensor = permuted_tensor[..., 0] + permuted_tensor[..., 1] * 1.0j

        return complex_tensor

    def inverse(self, input_tensor):
        is_non_standard_device = not input_tensor.device.type in ["cuda", "cpu"]

        if is_non_standard_device: input_tensor = input_tensor.cpu()

        stft_window = self.hann_window.to(input_tensor.device)

        batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins = self.calculate_inverse_dimensions(input_tensor)

        padded_tensor = self.pad_frequency_dimension(input_tensor, batch_dimensions, channel_dim, freq_dim, time_dim, num_freq_bins)

        complex_tensor = self.prepare_for_istft(padded_tensor, batch_dimensions, channel_dim, num_freq_bins, time_dim)

        istft_result = torch.istft(complex_tensor, n_fft=self.n_fft, hop_length=self.hop_length, window=stft_window, center=True)

        final_output = istft_result.reshape([*batch_dimensions, 2, -1])

        if is_non_standard_device: final_output = final_output.to(self.device)

        return final_output
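demix above reconstructs stems by windowed overlap-add: each chunk of model output is tapered with a Hanning window and accumulated into result, while divider accumulates the window itself, so the final division undoes the tapering. A toy 1-D sketch of the same scheme, with hypothetical sizes and a pass-through in place of the model:

import numpy as np

# Toy overlap-add mirroring MDXSeparator.demix: window each chunk,
# accumulate windowed output and window weight, then normalize.
chunk_size, step, total = 8, 4, 24                    # hypothetical sizes
signal = np.random.randn(total).astype("float32")

result = np.zeros(total, dtype="float32")
divider = np.zeros(total, dtype="float32")

for start in range(0, total, step):
    end = min(start + chunk_size, total)
    window = np.hanning(end - start)
    result[start:end] += signal[start:end] * window   # model output stand-in
    divider[start:end] += window

reconstructed = result / np.maximum(divider, 1e-8)    # guard near-zero bins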
main/library/predictors/FCPE.py
DELETED
@@ -1,600 +0,0 @@
import os
import math
import torch
import librosa
import numpy as np
import torch.nn as nn
import soundfile as sf
import torch.utils.data
import torch.nn.functional as F

from torch import nn
from typing import Union
from functools import partial
from einops import rearrange, repeat
from torchaudio.transforms import Resample
from local_attention import LocalAttention
from librosa.filters import mel as librosa_mel_fn
from torch.nn.utils.parametrizations import weight_norm

os.environ["LRU_CACHE_CAPACITY"] = "3"

def load_wav_to_torch(full_path, target_sr=None, return_empty_on_exception=False):
    try:
        data, sample_rate = sf.read(full_path, always_2d=True)
    except Exception as e:
        print(f"{full_path}: {e}")
        if return_empty_on_exception: return [], sample_rate or target_sr or 48000
        else: raise

    data = data[:, 0] if len(data.shape) > 1 else data
    assert len(data) > 2

    max_mag = (-np.iinfo(data.dtype).min if np.issubdtype(data.dtype, np.integer) else max(np.amax(data), -np.amin(data)))
    max_mag = ((2**31) + 1 if max_mag > (2**15) else ((2**15) + 1 if max_mag > 1.01 else 1.0))
    data = torch.FloatTensor(data.astype(np.float32)) / max_mag

    if (torch.isinf(data) | torch.isnan(data)).any() and return_empty_on_exception: return [], sample_rate or target_sr or 48000
    if target_sr is not None and sample_rate != target_sr:
        data = torch.from_numpy(librosa.core.resample(data.numpy(), orig_sr=sample_rate, target_sr=target_sr))
        sample_rate = target_sr

    return data, sample_rate

def dynamic_range_compression(x, C=1, clip_val=1e-5):
    return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
def dynamic_range_decompression(x, C=1):
    return np.exp(x) / C
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    return torch.log(torch.clamp(x, min=clip_val) * C)
def dynamic_range_decompression_torch(x, C=1):
    return torch.exp(x) / C

class STFT:
    def __init__(self, sr=22050, n_mels=80, n_fft=1024, win_size=1024, hop_length=256, fmin=20, fmax=11025, clip_val=1e-5):
        self.target_sr = sr
        self.n_mels = n_mels
        self.n_fft = n_fft
        self.win_size = win_size
        self.hop_length = hop_length
        self.fmin = fmin
        self.fmax = fmax
        self.clip_val = clip_val
        self.mel_basis = {}
        self.hann_window = {}

    def get_mel(self, y, keyshift=0, speed=1, center=False, train=False):
        sample_rate = self.target_sr
        n_mels = self.n_mels
        n_fft = self.n_fft
        win_size = self.win_size
        hop_length = self.hop_length
        fmin = self.fmin
        fmax = self.fmax
        clip_val = self.clip_val

        factor = 2 ** (keyshift / 12)
        n_fft_new = int(np.round(n_fft * factor))
        win_size_new = int(np.round(win_size * factor))
        hop_length_new = int(np.round(hop_length * speed))

        mel_basis = self.mel_basis if not train else {}
        hann_window = self.hann_window if not train else {}
        mel_basis_key = str(fmax) + "_" + str(y.device)

        if mel_basis_key not in mel_basis:
            mel = librosa_mel_fn(sr=sample_rate, n_fft=n_fft, n_mels=n_mels, fmin=fmin, fmax=fmax)
            mel_basis[mel_basis_key] = torch.from_numpy(mel).float().to(y.device)

        keyshift_key = str(keyshift) + "_" + str(y.device)

        if keyshift_key not in hann_window: hann_window[keyshift_key] = torch.hann_window(win_size_new).to(y.device)

        pad_left = (win_size_new - hop_length_new) // 2
        pad_right = max((win_size_new - hop_length_new + 1) // 2, win_size_new - y.size(-1) - pad_left)
        mode = "reflect" if pad_right < y.size(-1) else "constant"
        y = torch.nn.functional.pad(y.unsqueeze(1), (pad_left, pad_right), mode=mode)
        y = y.squeeze(1)

        spec = torch.stft(y, n_fft_new, hop_length=hop_length_new, win_length=win_size_new, window=hann_window[keyshift_key], center=center, pad_mode="reflect", normalized=False, onesided=True, return_complex=True)
        spec = torch.sqrt(spec.real.pow(2) + spec.imag.pow(2) + (1e-9))

        if keyshift != 0:
            size = n_fft // 2 + 1
            resize = spec.size(1)
            spec = (F.pad(spec, (0, 0, 0, size - resize)) if resize < size else spec[:, :size, :])
            spec = spec * win_size / win_size_new
        spec = torch.matmul(mel_basis[mel_basis_key], spec)
        spec = dynamic_range_compression_torch(spec, clip_val=clip_val)
        return spec

    def __call__(self, audiopath):
        audio, sr = load_wav_to_torch(audiopath, target_sr=self.target_sr)
        spect = self.get_mel(audio.unsqueeze(0)).squeeze(0)
        return spect

stft = STFT()

def softmax_kernel(data, *, projection_matrix, is_query, normalize_data=True, eps=1e-4, device=None):
    b, h, *_ = data.shape

    data_normalizer = (data.shape[-1] ** -0.25) if normalize_data else 1.0

    ratio = projection_matrix.shape[0] ** -0.5
    projection = repeat(projection_matrix, "j d -> b h j d", b=b, h=h)
    projection = projection.type_as(data)
    data_dash = torch.einsum("...id,...jd->...ij", (data_normalizer * data), projection)

    diag_data = data**2
    diag_data = torch.sum(diag_data, dim=-1)
    diag_data = (diag_data / 2.0) * (data_normalizer**2)
    diag_data = diag_data.unsqueeze(dim=-1)

    if is_query: data_dash = ratio * (torch.exp(data_dash - diag_data - torch.max(data_dash, dim=-1, keepdim=True).values) + eps)
    else: data_dash = ratio * (torch.exp(data_dash - diag_data + eps))

    return data_dash.type_as(data)

def orthogonal_matrix_chunk(cols, qr_uniform_q=False, device=None):
    unstructured_block = torch.randn((cols, cols), device=device)
    q, r = torch.linalg.qr(unstructured_block.cpu(), mode="reduced")
    q, r = map(lambda t: t.to(device), (q, r))

    if qr_uniform_q:
        d = torch.diag(r, 0)
        q *= d.sign()

    return q.t()

def exists(val):
    return val is not None
def empty(tensor):
    return tensor.numel() == 0
def default(val, d):
    return val if exists(val) else d
def cast_tuple(val):
    return (val,) if not isinstance(val, tuple) else val

class PCmer(nn.Module):
    def __init__(self, num_layers, num_heads, dim_model, dim_keys, dim_values, residual_dropout, attention_dropout):
        super().__init__()
        self.num_layers = num_layers
        self.num_heads = num_heads
        self.dim_model = dim_model
        self.dim_values = dim_values
        self.dim_keys = dim_keys
        self.residual_dropout = residual_dropout
        self.attention_dropout = attention_dropout

        self._layers = nn.ModuleList([_EncoderLayer(self) for _ in range(num_layers)])

    def forward(self, phone, mask=None):
        for layer in self._layers:
            phone = layer(phone, mask)

        return phone

class _EncoderLayer(nn.Module):
    def __init__(self, parent: PCmer):
        super().__init__()
        self.conformer = ConformerConvModule(parent.dim_model)
        self.norm = nn.LayerNorm(parent.dim_model)
        self.dropout = nn.Dropout(parent.residual_dropout)
        self.attn = SelfAttention(dim=parent.dim_model, heads=parent.num_heads, causal=False)

    def forward(self, phone, mask=None):
        phone = phone + (self.attn(self.norm(phone), mask=mask))
        phone = phone + (self.conformer(phone))
        return phone

def calc_same_padding(kernel_size):
    pad = kernel_size // 2
    return (pad, pad - (kernel_size + 1) % 2)

class Swish(nn.Module):
    def forward(self, x):
        return x * x.sigmoid()

class Transpose(nn.Module):
    def __init__(self, dims):
        super().__init__()
        assert len(dims) == 2, "dims == 2"
        self.dims = dims

    def forward(self, x):
        return x.transpose(*self.dims)

class GLU(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        out, gate = x.chunk(2, dim=self.dim)
        return out * gate.sigmoid()

class DepthWiseConv1d(nn.Module):
    def __init__(self, chan_in, chan_out, kernel_size, padding):
        super().__init__()
        self.padding = padding
        self.conv = nn.Conv1d(chan_in, chan_out, kernel_size, groups=chan_in)

    def forward(self, x):
        x = F.pad(x, self.padding)
        return self.conv(x)

class ConformerConvModule(nn.Module):
    def __init__(self, dim, causal=False, expansion_factor=2, kernel_size=31, dropout=0.0):
        super().__init__()

        inner_dim = dim * expansion_factor
        padding = calc_same_padding(kernel_size) if not causal else (kernel_size - 1, 0)

        self.net = nn.Sequential(nn.LayerNorm(dim), Transpose((1, 2)), nn.Conv1d(dim, inner_dim * 2, 1), GLU(dim=1), DepthWiseConv1d(inner_dim, inner_dim, kernel_size=kernel_size, padding=padding), Swish(), nn.Conv1d(inner_dim, dim, 1), Transpose((1, 2)), nn.Dropout(dropout))

    def forward(self, x):
        return self.net(x)

def linear_attention(q, k, v):
    if v is None:
        out = torch.einsum("...ed,...nd->...ne", k, q)
        return out
    else:
        k_cumsum = k.sum(dim=-2)
        D_inv = 1.0 / (torch.einsum("...nd,...d->...n", q, k_cumsum.type_as(q)) + 1e-8)
        context = torch.einsum("...nd,...ne->...de", k, v)
        out = torch.einsum("...de,...nd,...n->...ne", context, q, D_inv)
        return out

def gaussian_orthogonal_random_matrix(nb_rows, nb_columns, scaling=0, qr_uniform_q=False, device=None):
    nb_full_blocks = int(nb_rows / nb_columns)
    block_list = []

    for _ in range(nb_full_blocks):
        q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)
        block_list.append(q)

    remaining_rows = nb_rows - nb_full_blocks * nb_columns

    if remaining_rows > 0:
        q = orthogonal_matrix_chunk(nb_columns, qr_uniform_q=qr_uniform_q, device=device)
        block_list.append(q[:remaining_rows])

    final_matrix = torch.cat(block_list)

    if scaling == 0: multiplier = torch.randn((nb_rows, nb_columns), device=device).norm(dim=1)
    elif scaling == 1: multiplier = math.sqrt((float(nb_columns))) * torch.ones((nb_rows,), device=device)
    else: raise ValueError(f"Invalid scaling {scaling}")

    return torch.diag(multiplier) @ final_matrix

class FastAttention(nn.Module):
    def __init__(self, dim_heads, nb_features=None, ortho_scaling=0, causal=False, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, no_projection=False):
        super().__init__()
        nb_features = default(nb_features, int(dim_heads * math.log(dim_heads)))

        self.dim_heads = dim_heads
        self.nb_features = nb_features
        self.ortho_scaling = ortho_scaling

        self.create_projection = partial(gaussian_orthogonal_random_matrix, nb_rows=self.nb_features, nb_columns=dim_heads, scaling=ortho_scaling, qr_uniform_q=qr_uniform_q)
        projection_matrix = self.create_projection()
        self.register_buffer("projection_matrix", projection_matrix)

        self.generalized_attention = generalized_attention
        self.kernel_fn = kernel_fn
        self.no_projection = no_projection
        self.causal = causal

    @torch.no_grad()
    def redraw_projection_matrix(self):
        projections = self.create_projection()
        self.projection_matrix.copy_(projections)
        del projections

    def forward(self, q, k, v):
        device = q.device

        if self.no_projection:
            q = q.softmax(dim=-1)
            k = torch.exp(k) if self.causal else k.softmax(dim=-2)
        else:
            create_kernel = partial(softmax_kernel, projection_matrix=self.projection_matrix, device=device)

            q = create_kernel(q, is_query=True)
            k = create_kernel(k, is_query=False)

        attn_fn = linear_attention if not self.causal else self.causal_linear_fn

        if v is None:
            out = attn_fn(q, k, None)
            return out
        else:
            out = attn_fn(q, k, v)
            return out

class SelfAttention(nn.Module):
    def __init__(self, dim, causal=False, heads=8, dim_head=64, local_heads=0, local_window_size=256, nb_features=None, feature_redraw_interval=1000, generalized_attention=False, kernel_fn=nn.ReLU(), qr_uniform_q=False, dropout=0.0, no_projection=False):
        super().__init__()
        assert dim % heads == 0
        dim_head = default(dim_head, dim // heads)
        inner_dim = dim_head * heads
        self.fast_attention = FastAttention(dim_head, nb_features, causal=causal, generalized_attention=generalized_attention, kernel_fn=kernel_fn, qr_uniform_q=qr_uniform_q, no_projection=no_projection)
        self.heads = heads
        self.global_heads = heads - local_heads
        self.local_attn = (LocalAttention(window_size=local_window_size, causal=causal, autopad=True, dropout=dropout, look_forward=int(not causal), rel_pos_emb_config=(dim_head, local_heads)) if local_heads > 0 else None)
        self.to_q = nn.Linear(dim, inner_dim)
        self.to_k = nn.Linear(dim, inner_dim)
        self.to_v = nn.Linear(dim, inner_dim)
        self.to_out = nn.Linear(inner_dim, dim)
        self.dropout = nn.Dropout(dropout)

    @torch.no_grad()
    def redraw_projection_matrix(self):
        self.fast_attention.redraw_projection_matrix()

    def forward(self, x, context=None, mask=None, context_mask=None, name=None, inference=False, **kwargs):
        _, _, _, h, gh = *x.shape, self.heads, self.global_heads

        cross_attend = exists(context)
        context = default(context, x)
        context_mask = default(context_mask, mask) if not cross_attend else context_mask
        q, k, v = self.to_q(x), self.to_k(context), self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), (q, k, v))
        (q, lq), (k, lk), (v, lv) = map(lambda t: (t[:, :gh], t[:, gh:]), (q, k, v))

        attn_outs = []

        if not empty(q):
            if exists(context_mask):
                global_mask = context_mask[:, None, :, None]
                v.masked_fill_(~global_mask, 0.0)

            if cross_attend: pass
            else: out = self.fast_attention(q, k, v)

            attn_outs.append(out)

        if not empty(lq):
            assert (not cross_attend), "not cross_attend"
            out = self.local_attn(lq, lk, lv, input_mask=mask)
            attn_outs.append(out)

        out = torch.cat(attn_outs, dim=1)
        out = rearrange(out, "b h n d -> b n (h d)")
        out = self.to_out(out)
        return self.dropout(out)

def l2_regularization(model, l2_alpha):
    l2_loss = []
    for module in model.modules():
        if type(module) is nn.Conv2d: l2_loss.append((module.weight**2).sum() / 2.0)

    return l2_alpha * sum(l2_loss)

class _FCPE(nn.Module):
    def __init__(self, input_channel=128, out_dims=360, n_layers=12, n_chans=512, use_siren=False, use_full=False, loss_mse_scale=10, loss_l2_regularization=False, loss_l2_regularization_scale=1, loss_grad1_mse=False, loss_grad1_mse_scale=1, f0_max=1975.5, f0_min=32.70, confidence=False, threshold=0.05, use_input_conv=True):
        super().__init__()
        if use_siren: raise ValueError("Siren is not supported")
        if use_full: raise ValueError("Full model is not supported")

        self.loss_mse_scale = loss_mse_scale if (loss_mse_scale is not None) else 10
        self.loss_l2_regularization = (loss_l2_regularization if (loss_l2_regularization is not None) else False)
        self.loss_l2_regularization_scale = (loss_l2_regularization_scale if (loss_l2_regularization_scale is not None) else 1)
        self.loss_grad1_mse = loss_grad1_mse if (loss_grad1_mse is not None) else False
        self.loss_grad1_mse_scale = (loss_grad1_mse_scale if (loss_grad1_mse_scale is not None) else 1)
        self.f0_max = f0_max if (f0_max is not None) else 1975.5
        self.f0_min = f0_min if (f0_min is not None) else 32.70
        self.confidence = confidence if (confidence is not None) else False
        self.threshold = threshold if (threshold is not None) else 0.05
        self.use_input_conv = use_input_conv if (use_input_conv is not None) else True
        self.cent_table_b = torch.Tensor(np.linspace(self.f0_to_cent(torch.Tensor([f0_min]))[0], self.f0_to_cent(torch.Tensor([f0_max]))[0], out_dims))
        self.register_buffer("cent_table", self.cent_table_b)

        _leaky = nn.LeakyReLU()

        self.stack = nn.Sequential(nn.Conv1d(input_channel, n_chans, 3, 1, 1), nn.GroupNorm(4, n_chans), _leaky, nn.Conv1d(n_chans, n_chans, 3, 1, 1))
        self.decoder = PCmer(num_layers=n_layers, num_heads=8, dim_model=n_chans, dim_keys=n_chans, dim_values=n_chans, residual_dropout=0.1, attention_dropout=0.1)
        self.norm = nn.LayerNorm(n_chans)
        self.n_out = out_dims
        self.dense_out = weight_norm(nn.Linear(n_chans, self.n_out))

    def forward(self, mel, infer=True, gt_f0=None, return_hz_f0=False, cdecoder="local_argmax"):
        if cdecoder == "argmax": self.cdecoder = self.cents_decoder
        elif cdecoder == "local_argmax": self.cdecoder = self.cents_local_decoder

        x = (self.stack(mel.transpose(1, 2)).transpose(1, 2) if self.use_input_conv else mel)
        x = self.decoder(x)
        x = self.norm(x)
        x = self.dense_out(x)
        x = torch.sigmoid(x)

        if not infer:
            gt_cent_f0 = self.f0_to_cent(gt_f0)
            gt_cent_f0 = self.gaussian_blurred_cent(gt_cent_f0)
            loss_all = self.loss_mse_scale * F.binary_cross_entropy(x, gt_cent_f0)

            if self.loss_l2_regularization: loss_all = loss_all + l2_regularization(model=self, l2_alpha=self.loss_l2_regularization_scale)

            x = loss_all

        if infer:
            x = self.cdecoder(x)
            x = self.cent_to_f0(x)
            x = (1 + x / 700).log() if not return_hz_f0 else x

        return x

    def cents_decoder(self, y, mask=True):
        B, N, _ = y.size()
        ci = self.cent_table[None, None, :].expand(B, N, -1)
        rtn = torch.sum(ci * y, dim=-1, keepdim=True) / torch.sum(y, dim=-1, keepdim=True)
        if mask:
            confident = torch.max(y, dim=-1, keepdim=True)[0]
            confident_mask = torch.ones_like(confident)
            confident_mask[confident <= self.threshold] = float("-INF")
            rtn = rtn * confident_mask

        return (rtn, confident) if self.confidence else rtn

    def cents_local_decoder(self, y, mask=True):
        B, N, _ = y.size()
        ci = self.cent_table[None, None, :].expand(B, N, -1)
        confident, max_index = torch.max(y, dim=-1, keepdim=True)
        local_argmax_index = torch.arange(0, 9).to(max_index.device) + (max_index - 4)
        local_argmax_index = torch.clamp(local_argmax_index, 0, self.n_out - 1)
        ci_l = torch.gather(ci, -1, local_argmax_index)
        y_l = torch.gather(y, -1, local_argmax_index)
        rtn = torch.sum(ci_l * y_l, dim=-1, keepdim=True) / torch.sum(y_l, dim=-1, keepdim=True)

        if mask:
            confident_mask = torch.ones_like(confident)
            confident_mask[confident <= self.threshold] = float("-INF")
            rtn = rtn * confident_mask

        return (rtn, confident) if self.confidence else rtn

    def cent_to_f0(self, cent):
        return 10.0 * 2 ** (cent / 1200.0)

    def f0_to_cent(self, f0):
        return 1200.0 * torch.log2(f0 / 10.0)

    def gaussian_blurred_cent(self, cents):
        mask = (cents > 0.1) & (cents < (1200.0 * np.log2(self.f0_max / 10.0)))
        B, N, _ = cents.size()
        ci = self.cent_table[None, None, :].expand(B, N, -1)
        return torch.exp(-torch.square(ci - cents) / 1250) * mask.float()

class FCPEInfer:
    def __init__(self, model_path, device=None, dtype=torch.float32):
        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        ckpt = torch.load(model_path, map_location=torch.device(self.device))
        self.args = DotDict(ckpt["config"])
        self.dtype = dtype
        model = _FCPE(input_channel=self.args.model.input_channel, out_dims=self.args.model.out_dims, n_layers=self.args.model.n_layers, n_chans=self.args.model.n_chans, use_siren=self.args.model.use_siren, use_full=self.args.model.use_full, loss_mse_scale=self.args.loss.loss_mse_scale, loss_l2_regularization=self.args.loss.loss_l2_regularization, loss_l2_regularization_scale=self.args.loss.loss_l2_regularization_scale, loss_grad1_mse=self.args.loss.loss_grad1_mse, loss_grad1_mse_scale=self.args.loss.loss_grad1_mse_scale, f0_max=self.args.model.f0_max, f0_min=self.args.model.f0_min, confidence=self.args.model.confidence)
        model.to(self.device).to(self.dtype)
        model.load_state_dict(ckpt["model"])
        model.eval()
        self.model = model
        self.wav2mel = Wav2Mel(self.args, dtype=self.dtype, device=self.device)

    @torch.no_grad()
    def __call__(self, audio, sr, threshold=0.05):
        self.model.threshold = threshold
        audio = audio[None, :]
        mel = self.wav2mel(audio=audio, sample_rate=sr).to(self.dtype)
        f0 = self.model(mel=mel, infer=True, return_hz_f0=True)
        return f0

class Wav2Mel:
    def __init__(self, args, device=None, dtype=torch.float32):
        self.sample_rate = args.mel.sampling_rate
        self.hop_size = args.mel.hop_size

        if device is None: device = "cuda" if torch.cuda.is_available() else "cpu"

        self.device = device
        self.dtype = dtype
        self.stft = STFT(args.mel.sampling_rate, args.mel.num_mels, args.mel.n_fft, args.mel.win_size, args.mel.hop_size, args.mel.fmin, args.mel.fmax)
        self.resample_kernel = {}

    def extract_nvstft(self, audio, keyshift=0, train=False):
        mel = self.stft.get_mel(audio, keyshift=keyshift, train=train).transpose(1, 2)
        return mel

    def extract_mel(self, audio, sample_rate, keyshift=0, train=False):
        audio = audio.to(self.dtype).to(self.device)

        if sample_rate == self.sample_rate: audio_res = audio
        else:
            key_str = str(sample_rate)

            if key_str not in self.resample_kernel: self.resample_kernel[key_str] = Resample(sample_rate, self.sample_rate, lowpass_filter_width=128)

            self.resample_kernel[key_str] = (self.resample_kernel[key_str].to(self.dtype).to(self.device))
            audio_res = self.resample_kernel[key_str](audio)

        mel = self.extract_nvstft(audio_res, keyshift=keyshift, train=train)
        n_frames = int(audio.shape[1] // self.hop_size) + 1
        mel = (torch.cat((mel, mel[:, -1:, :]), 1) if n_frames > int(mel.shape[1]) else mel)
        mel = mel[:, :n_frames, :] if n_frames < int(mel.shape[1]) else mel
        return mel

    def __call__(self, audio, sample_rate, keyshift=0, train=False):
        return self.extract_mel(audio, sample_rate, keyshift=keyshift, train=train)

class DotDict(dict):
    def __getattr__(*args):
        val = dict.get(*args)
        return DotDict(val) if type(val) is dict else val

    __setattr__ = dict.__setitem__
    __delattr__ = dict.__delitem__

class F0Predictor(object):
    def compute_f0(self, wav, p_len): pass
    def compute_f0_uv(self, wav, p_len): pass

class FCPE(F0Predictor):
    def __init__(self, model_path, hop_length=512, f0_min=50, f0_max=1100, dtype=torch.float32, device=None, sample_rate=44100, threshold=0.05):
        self.fcpe = FCPEInfer(model_path, device=device, dtype=dtype)
        self.hop_length = hop_length
        self.f0_min = f0_min
        self.f0_max = f0_max
        self.device = device or ("cuda" if torch.cuda.is_available() else "cpu")
        self.threshold = threshold
        self.sample_rate = sample_rate
        self.dtype = dtype
        self.name = "fcpe"

    def repeat_expand(self, content: Union[torch.Tensor, np.ndarray], target_len, mode = "nearest"):
        ndim = content.ndim
        content = (content[None, None] if ndim == 1 else content[None] if ndim == 2 else content)
        assert content.ndim == 3
        is_np = isinstance(content, np.ndarray)
        content = torch.from_numpy(content) if is_np else content
        results = torch.nn.functional.interpolate(content, size=target_len, mode=mode)
        results = results.numpy() if is_np else results
        return results[0, 0] if ndim == 1 else results[0] if ndim == 2 else results

    def post_process(self, x, sample_rate, f0, pad_to):
        f0 = (torch.from_numpy(f0).float().to(x.device) if isinstance(f0, np.ndarray) else f0)
        f0 = self.repeat_expand(f0, pad_to) if pad_to is not None else f0

        vuv_vector = torch.zeros_like(f0)
        vuv_vector[f0 > 0.0] = 1.0
        vuv_vector[f0 <= 0.0] = 0.0

        nzindex = torch.nonzero(f0).squeeze()
        f0 = torch.index_select(f0, dim=0, index=nzindex).cpu().numpy()
        time_org = self.hop_length / sample_rate * nzindex.cpu().numpy()
        time_frame = np.arange(pad_to) * self.hop_length / sample_rate

        vuv_vector = F.interpolate(vuv_vector[None, None, :], size=pad_to)[0][0]

        if f0.shape[0] <= 0: return np.zeros(pad_to), vuv_vector.cpu().numpy()
        if f0.shape[0] == 1: return np.ones(pad_to) * f0[0], vuv_vector.cpu().numpy()

        f0 = np.interp(time_frame, time_org, f0, left=f0[0], right=f0[-1])
        return f0, vuv_vector.cpu().numpy()

    def compute_f0(self, wav, p_len=None):
        x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
        p_len = x.shape[0] // self.hop_length if p_len is None else p_len
        f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold)[0, :, 0]

        if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))

        return self.post_process(x, self.sample_rate, f0, p_len)[0]

    def compute_f0_uv(self, wav, p_len=None):
        x = torch.FloatTensor(wav).to(self.dtype).to(self.device)
        p_len = x.shape[0] // self.hop_length if p_len is None else p_len
        f0 = self.fcpe(x, sr=self.sample_rate, threshold=self.threshold)[0, :, 0]

        if torch.all(f0 == 0): return f0.cpu().numpy() if p_len is None else np.zeros(p_len), (f0.cpu().numpy() if p_len is None else np.zeros(p_len))

        return self.post_process(x, self.sample_rate, f0, p_len)
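The FCPE wrapper deleted above is the predictor entry point: it loads a checkpoint into FCPEInfer, runs the mel frontend, and interpolates the decoded f0 back to frame resolution in post_process. A usage sketch, with a hypothetical checkpoint path and sample settings:

import numpy as np

# "assets/models/fcpe.pt" is a hypothetical path; the checkpoint's
# "config" entry drives the mel frontend via Wav2Mel.
predictor = FCPE("assets/models/fcpe.pt", hop_length=160, sample_rate=16000, threshold=0.03)

wav = np.random.randn(16000).astype(np.float32)            # one second of audio
f0 = predictor.compute_f0(wav, p_len=wav.shape[0] // 160)  # f0 per frame
f0_uv = predictor.compute_f0_uv(wav)                       # (f0, voiced/unvoiced mask)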
main/library/predictors/RMVPE.py
DELETED
@@ -1,270 +0,0 @@
-import torch
-import numpy as np
-import torch.nn as nn
-import torch.nn.functional as F
-
-from typing import List
-from librosa.filters import mel
-
-N_MELS = 128
-N_CLASS = 360
-
-
-class ConvBlockRes(nn.Module):
-    def __init__(self, in_channels, out_channels, momentum=0.01):
-        super(ConvBlockRes, self).__init__()
-        self.conv = nn.Sequential(nn.Conv2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU(), nn.Conv2d(in_channels=out_channels, out_channels=out_channels, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
-
-        if in_channels != out_channels:
-            self.shortcut = nn.Conv2d(in_channels, out_channels, (1, 1))
-            self.is_shortcut = True
-        else: self.is_shortcut = False
-
-    def forward(self, x):
-        if self.is_shortcut: return self.conv(x) + self.shortcut(x)
-        else: return self.conv(x) + x
-
-class ResEncoderBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, kernel_size, n_blocks=1, momentum=0.01):
-        super(ResEncoderBlock, self).__init__()
-        self.n_blocks = n_blocks
-        self.conv = nn.ModuleList()
-        self.conv.append(ConvBlockRes(in_channels, out_channels, momentum))
-
-        for _ in range(n_blocks - 1):
-            self.conv.append(ConvBlockRes(out_channels, out_channels, momentum))
-
-        self.kernel_size = kernel_size
-
-        if self.kernel_size is not None:
-            self.pool = nn.AvgPool2d(kernel_size=kernel_size)
-
-    def forward(self, x):
-        for i in range(self.n_blocks):
-            x = self.conv[i](x)
-
-        if self.kernel_size is not None: return x, self.pool(x)
-        else: return x
-
-class Encoder(nn.Module):
-    def __init__(self, in_channels, in_size, n_encoders, kernel_size, n_blocks, out_channels=16, momentum=0.01):
-        super(Encoder, self).__init__()
-        self.n_encoders = n_encoders
-        self.bn = nn.BatchNorm2d(in_channels, momentum=momentum)
-        self.layers = nn.ModuleList()
-        self.latent_channels = []
-
-        for i in range(self.n_encoders):
-            self.layers.append(ResEncoderBlock(in_channels, out_channels, kernel_size, n_blocks, momentum=momentum))
-            self.latent_channels.append([out_channels, in_size])
-            in_channels = out_channels
-            out_channels *= 2
-            in_size //= 2
-
-        self.out_size = in_size
-        self.out_channel = out_channels
-
-    def forward(self, x: torch.Tensor):
-        concat_tensors: List[torch.Tensor] = []
-        x = self.bn(x)
-
-        for i in range(self.n_encoders):
-            t, x = self.layers[i](x)
-            concat_tensors.append(t)
-
-        return x, concat_tensors
-
-class Intermediate(nn.Module):
-    def __init__(self, in_channels, out_channels, n_inters, n_blocks, momentum=0.01):
-        super(Intermediate, self).__init__()
-        self.n_inters = n_inters
-        self.layers = nn.ModuleList()
-        self.layers.append(ResEncoderBlock(in_channels, out_channels, None, n_blocks, momentum))
-
-        for _ in range(self.n_inters - 1):
-            self.layers.append(ResEncoderBlock(out_channels, out_channels, None, n_blocks, momentum))
-
-    def forward(self, x):
-        for i in range(self.n_inters):
-            x = self.layers[i](x)
-        return x
-
-class ResDecoderBlock(nn.Module):
-    def __init__(self, in_channels, out_channels, stride, n_blocks=1, momentum=0.01):
-        super(ResDecoderBlock, self).__init__()
-        out_padding = (0, 1) if stride == (1, 2) else (1, 1)
-        self.n_blocks = n_blocks
-        self.conv1 = nn.Sequential(nn.ConvTranspose2d(in_channels=in_channels, out_channels=out_channels, kernel_size=(3, 3), stride=stride, padding=(1, 1), output_padding=out_padding, bias=False), nn.BatchNorm2d(out_channels, momentum=momentum), nn.ReLU())
-        self.conv2 = nn.ModuleList()
-        self.conv2.append(ConvBlockRes(out_channels * 2, out_channels, momentum))
-        for _ in range(n_blocks - 1):
-            self.conv2.append(ConvBlockRes(out_channels, out_channels, momentum))
-
-    def forward(self, x, concat_tensor):
-        x = self.conv1(x)
-        x = torch.cat((x, concat_tensor), dim=1)
-        for i in range(self.n_blocks):
-            x = self.conv2[i](x)
-
-        return x
-
-class Decoder(nn.Module):
-    def __init__(self, in_channels, n_decoders, stride, n_blocks, momentum=0.01):
-        super(Decoder, self).__init__()
-        self.layers = nn.ModuleList()
-        self.n_decoders = n_decoders
-
-        for _ in range(self.n_decoders):
-            out_channels = in_channels // 2
-            self.layers.append(ResDecoderBlock(in_channels, out_channels, stride, n_blocks, momentum))
-            in_channels = out_channels
-
-    def forward(self, x, concat_tensors):
-        for i in range(self.n_decoders):
-            x = self.layers[i](x, concat_tensors[-1 - i])
-
-        return x
-
-class DeepUnet(nn.Module):
-    def __init__(self, kernel_size, n_blocks, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
-        super(DeepUnet, self).__init__()
-        self.encoder = Encoder(in_channels, 128, en_de_layers, kernel_size, n_blocks, en_out_channels)
-        self.intermediate = Intermediate(self.encoder.out_channel // 2, self.encoder.out_channel, inter_layers, n_blocks)
-        self.decoder = Decoder(self.encoder.out_channel, en_de_layers, kernel_size, n_blocks)
-
-    def forward(self, x):
-        x, concat_tensors = self.encoder(x)
-        x = self.intermediate(x)
-        x = self.decoder(x, concat_tensors)
-        return x
-
-class E2E(nn.Module):
-    def __init__(self, n_blocks, n_gru, kernel_size, en_de_layers=5, inter_layers=4, in_channels=1, en_out_channels=16):
-        super(E2E, self).__init__()
-        self.unet = DeepUnet(kernel_size, n_blocks, en_de_layers, inter_layers, in_channels, en_out_channels)
-        self.cnn = nn.Conv2d(en_out_channels, 3, (3, 3), padding=(1, 1))
-
-        if n_gru: self.fc = nn.Sequential(BiGRU(3 * 128, 256, n_gru), nn.Linear(512, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
-        else: self.fc = nn.Sequential(nn.Linear(3 * N_MELS, N_CLASS), nn.Dropout(0.25), nn.Sigmoid())
-
-    def forward(self, mel):
-        mel = mel.transpose(-1, -2).unsqueeze(1)
-        x = self.cnn(self.unet(mel)).transpose(1, 2).flatten(-2)
-        x = self.fc(x)
-        return x
-
-class MelSpectrogram(torch.nn.Module):
-    def __init__(self, is_half, n_mel_channels, sample_rate, win_length, hop_length, n_fft=None, mel_fmin=0, mel_fmax=None, clamp=1e-5):
-        super().__init__()
-        n_fft = win_length if n_fft is None else n_fft
-        self.hann_window = {}
-        mel_basis = mel(sr=sample_rate, n_fft=n_fft, n_mels=n_mel_channels, fmin=mel_fmin, fmax=mel_fmax, htk=True)
-        mel_basis = torch.from_numpy(mel_basis).float()
-        self.register_buffer("mel_basis", mel_basis)
-        self.n_fft = win_length if n_fft is None else n_fft
-        self.hop_length = hop_length
-        self.win_length = win_length
-        self.sample_rate = sample_rate
-        self.n_mel_channels = n_mel_channels
-        self.clamp = clamp
-        self.is_half = is_half
-
-    def forward(self, audio, keyshift=0, speed=1, center=True):
-        factor = 2 ** (keyshift / 12)
-        n_fft_new = int(np.round(self.n_fft * factor))
-        win_length_new = int(np.round(self.win_length * factor))
-        hop_length_new = int(np.round(self.hop_length * speed))
-        keyshift_key = str(keyshift) + "_" + str(audio.device)
-
-        if keyshift_key not in self.hann_window: self.hann_window[keyshift_key] = torch.hann_window(win_length_new).to(audio.device)
-
-        fft = torch.stft(audio, n_fft=n_fft_new, hop_length=hop_length_new, win_length=win_length_new, window=self.hann_window[keyshift_key], center=center, return_complex=True)
-        magnitude = torch.sqrt(fft.real.pow(2) + fft.imag.pow(2))
-
-        if keyshift != 0:
-            size = self.n_fft // 2 + 1
-            resize = magnitude.size(1)
-
-            if resize < size: magnitude = F.pad(magnitude, (0, 0, 0, size - resize))
-
-            magnitude = magnitude[:, :size, :] * self.win_length / win_length_new
-
-        mel_output = torch.matmul(self.mel_basis, magnitude)
-
-        if self.is_half: mel_output = mel_output.half()
-
-        log_mel_spec = torch.log(torch.clamp(mel_output, min=self.clamp))
-        return log_mel_spec
-
-class RMVPE:
-    def __init__(self, model_path, is_half, device=None):
-        self.resample_kernel = {}
-        model = E2E(4, 1, (2, 2))
-        ckpt = torch.load(model_path, map_location="cpu")
-        model.load_state_dict(ckpt)
-        model.eval()
-
-        if is_half: model = model.half()
-
-        self.model = model
-        self.resample_kernel = {}
-        self.is_half = is_half
-        self.device = device
-        self.mel_extractor = MelSpectrogram(is_half, N_MELS, 16000, 1024, 160, None, 30, 8000).to(device)
-        self.model = self.model.to(device)
-        cents_mapping = 20 * np.arange(N_CLASS) + 1997.3794084376191
-        self.cents_mapping = np.pad(cents_mapping, (4, 4))
-
-    def mel2hidden(self, mel):
-        with torch.no_grad():
-            n_frames = mel.shape[-1]
-            mel = F.pad(mel, (0, 32 * ((n_frames - 1) // 32 + 1) - n_frames), mode="reflect")
-            hidden = self.model(mel)
-            return hidden[:, :n_frames]
-
-    def decode(self, hidden, thred=0.03):
-        cents_pred = self.to_local_average_cents(hidden, thred=thred)
-        f0 = 10 * (2 ** (cents_pred / 1200))
-        f0[f0 == 10] = 0
-        return f0
-
-    def infer_from_audio(self, audio, thred=0.03):
-        audio = torch.from_numpy(audio).float().to(self.device).unsqueeze(0)
-        mel = self.mel_extractor(audio, center=True)
-        hidden = self.mel2hidden(mel)
-        hidden = hidden.squeeze(0).cpu().numpy()
-
-        if self.is_half: hidden = hidden.astype("float32")
-
-        f0 = self.decode(hidden, thred=thred)
-        return f0
-
-    def to_local_average_cents(self, salience, thred=0.05):
-        center = np.argmax(salience, axis=1)
-        salience = np.pad(salience, ((0, 0), (4, 4)))
-        center += 4
-        todo_salience = []
-        todo_cents_mapping = []
-        starts = center - 4
-        ends = center + 5
-
-        for idx in range(salience.shape[0]):
-            todo_salience.append(salience[:, starts[idx] : ends[idx]][idx])
-            todo_cents_mapping.append(self.cents_mapping[starts[idx] : ends[idx]])
-
-        todo_salience = np.array(todo_salience)
-        todo_cents_mapping = np.array(todo_cents_mapping)
-        product_sum = np.sum(todo_salience * todo_cents_mapping, 1)
-        weight_sum = np.sum(todo_salience, 1)
-        devided = product_sum / weight_sum
-        maxx = np.max(salience, axis=1)
-        devided[maxx <= thred] = 0
-        return devided
-
-class BiGRU(nn.Module):
-    def __init__(self, input_features, hidden_features, num_layers):
-        super(BiGRU, self).__init__()
-        self.gru = nn.GRU(input_features, hidden_features, num_layers=num_layers, batch_first=True, bidirectional=True)
-
-    def forward(self, x):
-        return self.gru(x)[0]

main/library/uvr5_separator/common_separator.py
DELETED
@@ -1,270 +0,0 @@
-import os
-import gc
-import sys
-import torch
-import librosa
-
-import numpy as np
-import soundfile as sf
-
-from logging import Logger
-from pydub import AudioSegment
-
-now_dir = os.getcwd()
-sys.path.append(now_dir)
-
-from . import spec_utils
-from main.configs.config import Config
-
-translations = Config().translations
-
-class CommonSeparator:
-    ALL_STEMS = "All Stems"
-    VOCAL_STEM = "Vocals"
-    INST_STEM = "Instrumental"
-    OTHER_STEM = "Other"
-    BASS_STEM = "Bass"
-    DRUM_STEM = "Drums"
-    GUITAR_STEM = "Guitar"
-    PIANO_STEM = "Piano"
-    SYNTH_STEM = "Synthesizer"
-    STRINGS_STEM = "Strings"
-    WOODWINDS_STEM = "Woodwinds"
-    BRASS_STEM = "Brass"
-    WIND_INST_STEM = "Wind Inst"
-    NO_OTHER_STEM = "No Other"
-    NO_BASS_STEM = "No Bass"
-    NO_DRUM_STEM = "No Drums"
-    NO_GUITAR_STEM = "No Guitar"
-    NO_PIANO_STEM = "No Piano"
-    NO_SYNTH_STEM = "No Synthesizer"
-    NO_STRINGS_STEM = "No Strings"
-    NO_WOODWINDS_STEM = "No Woodwinds"
-    NO_WIND_INST_STEM = "No Wind Inst"
-    NO_BRASS_STEM = "No Brass"
-    PRIMARY_STEM = "Primary Stem"
-    SECONDARY_STEM = "Secondary Stem"
-    LEAD_VOCAL_STEM = "lead_only"
-    BV_VOCAL_STEM = "backing_only"
-    LEAD_VOCAL_STEM_I = "with_lead_vocals"
-    BV_VOCAL_STEM_I = "with_backing_vocals"
-    LEAD_VOCAL_STEM_LABEL = "Lead Vocals"
-    BV_VOCAL_STEM_LABEL = "Backing Vocals"
-    NO_STEM = "No "
-
-    STEM_PAIR_MAPPER = {VOCAL_STEM: INST_STEM, INST_STEM: VOCAL_STEM, LEAD_VOCAL_STEM: BV_VOCAL_STEM, BV_VOCAL_STEM: LEAD_VOCAL_STEM, PRIMARY_STEM: SECONDARY_STEM}
-
-    NON_ACCOM_STEMS = (VOCAL_STEM, OTHER_STEM, BASS_STEM, DRUM_STEM, GUITAR_STEM, PIANO_STEM, SYNTH_STEM, STRINGS_STEM, WOODWINDS_STEM, BRASS_STEM, WIND_INST_STEM)
-
-
-    def __init__(self, config):
-        self.logger: Logger = config.get("logger")
-        self.log_level: int = config.get("log_level")
-        self.torch_device = config.get("torch_device")
-        self.torch_device_cpu = config.get("torch_device_cpu")
-        self.torch_device_mps = config.get("torch_device_mps")
-        self.onnx_execution_provider = config.get("onnx_execution_provider")
-        self.model_name = config.get("model_name")
-        self.model_path = config.get("model_path")
-        self.model_data = config.get("model_data")
-        self.output_dir = config.get("output_dir")
-        self.output_format = config.get("output_format")
-        self.output_bitrate = config.get("output_bitrate")
-        self.normalization_threshold = config.get("normalization_threshold")
-        self.enable_denoise = config.get("enable_denoise")
-        self.output_single_stem = config.get("output_single_stem")
-        self.invert_using_spec = config.get("invert_using_spec")
-        self.sample_rate = config.get("sample_rate")
-
-        self.primary_stem_name = None
-        self.secondary_stem_name = None
-
-        if "training" in self.model_data and "instruments" in self.model_data["training"]:
-            instruments = self.model_data["training"]["instruments"]
-
-            if instruments:
-                self.primary_stem_name = instruments[0]
-                self.secondary_stem_name = instruments[1] if len(instruments) > 1 else self.secondary_stem(self.primary_stem_name)
-
-        if self.primary_stem_name is None:
-            self.primary_stem_name = self.model_data.get("primary_stem", "Vocals")
-            self.secondary_stem_name = self.secondary_stem(self.primary_stem_name)
-
-        self.is_karaoke = self.model_data.get("is_karaoke", False)
-        self.is_bv_model = self.model_data.get("is_bv_model", False)
-        self.bv_model_rebalance = self.model_data.get("is_bv_model_rebalanced", 0)
-
-        self.logger.debug(translations["info"].format(model_name=self.model_name, model_path=self.model_path))
-        self.logger.debug(translations["info_2"].format(output_dir=self.output_dir, output_format=self.output_format))
-        self.logger.debug(translations["info_3"].format(normalization_threshold=self.normalization_threshold))
-        self.logger.debug(translations["info_4"].format(enable_denoise=self.enable_denoise, output_single_stem=self.output_single_stem))
-        self.logger.debug(translations["info_5"].format(invert_using_spec=self.invert_using_spec, sample_rate=self.sample_rate))
-        self.logger.debug(translations["info_6"].format(primary_stem_name=self.primary_stem_name, secondary_stem_name=self.secondary_stem_name))
-        self.logger.debug(translations["info_7"].format(is_karaoke=self.is_karaoke, is_bv_model=self.is_bv_model, bv_model_rebalance=self.bv_model_rebalance))
-
-        self.audio_file_path = None
-        self.audio_file_base = None
-        self.primary_source = None
-        self.secondary_source = None
-        self.primary_stem_output_path = None
-        self.secondary_stem_output_path = None
-        self.cached_sources_map = {}
-
-    def secondary_stem(self, primary_stem: str):
-        primary_stem = primary_stem if primary_stem else self.NO_STEM
-
-        return self.STEM_PAIR_MAPPER[primary_stem] if primary_stem in self.STEM_PAIR_MAPPER else primary_stem.replace(self.NO_STEM, "") if self.NO_STEM in primary_stem else f"{self.NO_STEM}{primary_stem}"
-
-    def separate(self, audio_file_path):
-        pass
-
-    def final_process(self, stem_path, source, stem_name):
-        self.logger.debug(translations["success_process"].format(stem_name=stem_name))
-        self.write_audio(stem_path, source)
-
-        return {stem_name: source}
-
-    def cached_sources_clear(self):
-        self.cached_sources_map = {}
-
-    def cached_source_callback(self, model_architecture, model_name=None):
-        model, sources = None, None
-        mapper = self.cached_sources_map[model_architecture]
-
-        for key, value in mapper.items():
-            if model_name in key:
-                model = key
-                sources = value
-
-        return model, sources
-
-    def cached_model_source_holder(self, model_architecture, sources, model_name=None):
-        self.cached_sources_map[model_architecture] = {**self.cached_sources_map.get(model_architecture, {}), **{model_name: sources}}
-
-    def prepare_mix(self, mix):
-        audio_path = mix
-
-        if not isinstance(mix, np.ndarray):
-            self.logger.debug(f"{translations['load_audio']}: {mix}")
-            mix, sr = librosa.load(mix, mono=False, sr=self.sample_rate)
-            self.logger.debug(translations["load_audio_success"].format(sr=sr, shape=mix.shape))
-        else:
-            self.logger.debug(translations["convert_mix"])
-            mix = mix.T
-            self.logger.debug(translations["convert_shape"].format(shape=mix.shape))
-
-        if isinstance(audio_path, str):
-            if not np.any(mix):
-                error_msg = translations["audio_not_valid"].format(audio_path=audio_path)
-                self.logger.error(error_msg)
-                raise ValueError(error_msg)
-            else: self.logger.debug(translations["audio_valid"])
-
-        if mix.ndim == 1:
-            self.logger.debug(translations["mix_single"])
-            mix = np.asfortranarray([mix, mix])
-            self.logger.debug(translations["convert_mix_audio"])
-
-        self.logger.debug(translations["mix_success_2"])
-        return mix
-
-    def write_audio(self, stem_path: str, stem_source):
-        duration_seconds = librosa.get_duration(filename=self.audio_file_path)
-        duration_hours = duration_seconds / 3600
-        self.logger.info(translations["duration"].format(duration_hours=f"{duration_hours:.2f}", duration_seconds=f"{duration_seconds:.2f}"))
-
-        if duration_hours >= 1:
-            self.logger.warning(translations["write"].format(name="soundfile"))
-            self.write_audio_soundfile(stem_path, stem_source)
-        else:
-            self.logger.info(translations["write"].format(name="pydub"))
-            self.write_audio_pydub(stem_path, stem_source)
-
-    def write_audio_pydub(self, stem_path: str, stem_source):
-        self.logger.debug(f"{translations['write_audio'].format(name='write_audio_pydub')} {stem_path}")
-
-        stem_source = spec_utils.normalize(wave=stem_source, max_peak=self.normalization_threshold)
-
-        if np.max(np.abs(stem_source)) < 1e-6:
-            self.logger.warning(translations["original_not_valid"])
-            return
-
-        if self.output_dir:
-            os.makedirs(self.output_dir, exist_ok=True)
-            stem_path = os.path.join(self.output_dir, stem_path)
-
-        self.logger.debug(f"{translations['shape_audio']}: {stem_source.shape}")
-        self.logger.debug(f"{translations['convert_data']}: {stem_source.dtype}")
-
-        if stem_source.dtype != np.int16:
-            stem_source = (stem_source * 32767).astype(np.int16)
-            self.logger.debug(translations["original_source_to_int16"])
-
-        stem_source_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
-        stem_source_interleaved[0::2] = stem_source[:, 0]
-        stem_source_interleaved[1::2] = stem_source[:, 1]
-
-        self.logger.debug(f"{translations['shape_audio_2']}: {stem_source_interleaved.shape}")
-
-        try:
-            audio_segment = AudioSegment(stem_source_interleaved.tobytes(), frame_rate=self.sample_rate, sample_width=stem_source.dtype.itemsize, channels=2)
-            self.logger.debug(translations["create_audiosegment"])
-        except (IOError, ValueError) as e:
-            self.logger.error(f"{translations['create_audiosegment_error']}: {e}")
-            return
-
-        file_format = stem_path.lower().split(".")[-1]
-
-        if file_format == "m4a": file_format = "mp4"
-        elif file_format == "mka": file_format = "matroska"
-
-        bitrate = "320k" if file_format == "mp3" and self.output_bitrate is None else self.output_bitrate
-
-        try:
-            audio_segment.export(stem_path, format=file_format, bitrate=bitrate)
-            self.logger.debug(f"{translations['export_success']} {stem_path}")
-        except (IOError, ValueError) as e:
-            self.logger.error(f"{translations['export_error']}: {e}")
-
-    def write_audio_soundfile(self, stem_path: str, stem_source):
-        self.logger.debug(f"{translations['write_audio'].format(name='write_audio_soundfile')}: {stem_path}")
-
-        if stem_source.shape[1] == 2:
-            if stem_source.flags["F_CONTIGUOUS"]: stem_source = np.ascontiguousarray(stem_source)
-            else:
-                stereo_interleaved = np.empty((2 * stem_source.shape[0],), dtype=np.int16)
-                stereo_interleaved[0::2] = stem_source[:, 0]
-
-                stereo_interleaved[1::2] = stem_source[:, 1]
-                stem_source = stereo_interleaved
-
-        self.logger.debug(f"{translations['shape_audio_2']}: {stem_source.shape}")
-
-        try:
-            sf.write(stem_path, stem_source, self.sample_rate)
-            self.logger.debug(f"{translations['export_success']} {stem_path}")
-        except Exception as e:
-            self.logger.error(f"{translations['export_error']}: {e}")
-
-    def clear_gpu_cache(self):
-        self.logger.debug(translations["clean"])
-        gc.collect()
-
-        if self.torch_device == torch.device("mps"):
-            self.logger.debug(translations["clean_cache"].format(name="MPS"))
-            torch.mps.empty_cache()
-
-        if self.torch_device == torch.device("cuda"):
-            self.logger.debug(translations["clean_cache"].format(name="CUDA"))
-            torch.cuda.empty_cache()
-
-    def clear_file_specific_paths(self):
-        self.logger.info(translations["del_path"])
-        self.audio_file_path = None
-        self.audio_file_base = None
-
-        self.primary_source = None
-        self.secondary_source = None
-
-        self.primary_stem_output_path = None
-        self.secondary_stem_output_path = None

main/library/uvr5_separator/demucs/apply.py
DELETED
@@ -1,280 +0,0 @@
-import tqdm
-import torch
-import random
-
-import typing as tp
-
-from torch import nn
-from torch.nn import functional as F
-from concurrent.futures import ThreadPoolExecutor
-
-from .demucs import Demucs
-from .hdemucs import HDemucs
-from .htdemucs import HTDemucs
-from .utils import center_trim
-
-
-Model = tp.Union[Demucs, HDemucs, HTDemucs]
-
-class DummyPoolExecutor:
-    class DummyResult:
-        def __init__(self, func, *args, **kwargs):
-            self.func = func
-            self.args = args
-            self.kwargs = kwargs
-
-        def result(self):
-            return self.func(*self.args, **self.kwargs)
-
-    def __init__(self, workers=0):
-        pass
-
-    def submit(self, func, *args, **kwargs):
-        return DummyPoolExecutor.DummyResult(func, *args, **kwargs)
-
-    def __enter__(self):
-        return self
-
-    def __exit__(self, exc_type, exc_value, exc_tb):
-        return
-
-class BagOfModels(nn.Module):
-    def __init__(self, models: tp.List[Model], weights: tp.Optional[tp.List[tp.List[float]]] = None, segment: tp.Optional[float] = None):
-        super().__init__()
-        assert len(models) > 0
-        first = models[0]
-
-        for other in models:
-            assert other.sources == first.sources
-            assert other.samplerate == first.samplerate
-            assert other.audio_channels == first.audio_channels
-
-            if segment is not None: other.segment = segment
-
-        self.audio_channels = first.audio_channels
-        self.samplerate = first.samplerate
-        self.sources = first.sources
-        self.models = nn.ModuleList(models)
-
-        if weights is None: weights = [[1.0 for _ in first.sources] for _ in models]
-        else:
-            assert len(weights) == len(models)
-
-            for weight in weights:
-                assert len(weight) == len(first.sources)
-
-        self.weights = weights
-
-    def forward(self, x):
-        raise NotImplementedError("`apply_model`")
-
-
-class TensorChunk:
-    def __init__(self, tensor, offset=0, length=None):
-        total_length = tensor.shape[-1]
-        assert offset >= 0
-        assert offset < total_length
-
-        length = total_length - offset if length is None else min(total_length - offset, length)
-
-        if isinstance(tensor, TensorChunk):
-            self.tensor = tensor.tensor
-            self.offset = offset + tensor.offset
-        else:
-            self.tensor = tensor
-            self.offset = offset
-
-        self.length = length
-        self.device = tensor.device
-
-    @property
-    def shape(self):
-        shape = list(self.tensor.shape)
-        shape[-1] = self.length
-
-        return shape
-
-    def padded(self, target_length):
-        delta = target_length - self.length
-        total_length = self.tensor.shape[-1]
-        assert delta >= 0
-
-        start = self.offset - delta // 2
-        end = start + target_length
-
-        correct_start = max(0, start)
-        correct_end = min(total_length, end)
-
-        pad_left = correct_start - start
-        pad_right = end - correct_end
-
-        out = F.pad(self.tensor[..., correct_start:correct_end], (pad_left, pad_right))
-
-        assert out.shape[-1] == target_length
-
-        return out
-
-
-def tensor_chunk(tensor_or_chunk):
-    if isinstance(tensor_or_chunk, TensorChunk): return tensor_or_chunk
-    else:
-        assert isinstance(tensor_or_chunk, torch.Tensor)
-
-        return TensorChunk(tensor_or_chunk)
-
-
-def apply_model(model, mix, shifts=1, split=True, overlap=0.25, transition_power=1.0, static_shifts=1, set_progress_bar=None, device=None, progress=False, num_workers=0, pool=None):
-    global fut_length
-    global bag_num
-    global prog_bar
-
-    device = mix.device if device is None else torch.device(device)
-
-    if pool is None: pool = ThreadPoolExecutor(num_workers) if num_workers > 0 and device.type == "cpu" else DummyPoolExecutor()
-
-    kwargs = {
-        "shifts": shifts,
-        "split": split,
-        "overlap": overlap,
-        "transition_power": transition_power,
-        "progress": progress,
-        "device": device,
-        "pool": pool,
-        "set_progress_bar": set_progress_bar,
-        "static_shifts": static_shifts,
-    }
-
-    if isinstance(model, BagOfModels):
-        estimates = 0
-        totals = [0] * len(model.sources)
-        bag_num = len(model.models)
-        fut_length = 0
-        prog_bar = 0
-        current_model = 0
-
-        for sub_model, weight in zip(model.models, model.weights):
-            original_model_device = next(iter(sub_model.parameters())).device
-            sub_model.to(device)
-            fut_length += fut_length
-            current_model += 1
-
-            out = apply_model(sub_model, mix, **kwargs)
-            sub_model.to(original_model_device)
-
-            for k, inst_weight in enumerate(weight):
-                out[:, k, :, :] *= inst_weight
-                totals[k] += inst_weight
-
-            estimates += out
-            del out
-
-        for k in range(estimates.shape[1]):
-            estimates[:, k, :, :] /= totals[k]
-
-        return estimates
-
-    model.to(device)
-    model.eval()
-
-    assert transition_power >= 1
-
-    batch, channels, length = mix.shape
-
-    if shifts:
-        kwargs["shifts"] = 0
-        max_shift = int(0.5 * model.samplerate)
-        mix = tensor_chunk(mix)
-        padded_mix = mix.padded(length + 2 * max_shift)
-        out = 0
-
-        for _ in range(shifts):
-            offset = random.randint(0, max_shift)
-            shifted = TensorChunk(padded_mix, offset, length + max_shift - offset)
-            shifted_out = apply_model(model, shifted, **kwargs)
-            out += shifted_out[..., max_shift - offset :]
-
-        out /= shifts
-
-        return out
-    elif split:
-        kwargs["split"] = False
-        out = torch.zeros(batch, len(model.sources), channels, length, device=mix.device)
-        sum_weight = torch.zeros(length, device=mix.device)
-        segment = int(model.samplerate * model.segment)
-        stride = int((1 - overlap) * segment)
-        offsets = range(0, length, stride)
-
-        weight = torch.cat([torch.arange(1, segment // 2 + 1, device=device), torch.arange(segment - segment // 2, 0, -1, device=device)])
-        assert len(weight) == segment
-
-        weight = (weight / weight.max()) ** transition_power
-        futures = []
-
-        for offset in offsets:
-            chunk = TensorChunk(mix, offset, segment)
-            future = pool.submit(apply_model, model, chunk, **kwargs)
-            futures.append((future, offset))
-            offset += segment
-
-        if progress: futures = tqdm.tqdm(futures)
-
-        for future, offset in futures:
-            if set_progress_bar:
-                fut_length = len(futures) * bag_num * static_shifts
-                prog_bar += 1
-                set_progress_bar(0.1, (0.8 / fut_length * prog_bar))
-
-            chunk_out = future.result()
-            chunk_length = chunk_out.shape[-1]
-
-            out[..., offset : offset + segment] += (weight[:chunk_length] * chunk_out).to(mix.device)
-            sum_weight[offset : offset + segment] += weight[:chunk_length].to(mix.device)
-
-        assert sum_weight.min() > 0
-
-        out /= sum_weight
-
-        return out
-    else:
-        valid_length = model.valid_length(length) if hasattr(model, "valid_length") else length
-
-        mix = tensor_chunk(mix)
-        padded_mix = mix.padded(valid_length).to(device)
-
-        with torch.no_grad():
-            out = model(padded_mix)
-
-        return center_trim(out, length)
-
-
-def demucs_segments(demucs_segment, demucs_model):
-    if demucs_segment == "Default":
-        segment = None
-
-        if isinstance(demucs_model, BagOfModels):
-            if segment is not None:
-                for sub in demucs_model.models:
-                    sub.segment = segment
-        else:
-            if segment is not None: sub.segment = segment
-    else:
-        try:
-            segment = int(demucs_segment)
-
-            if isinstance(demucs_model, BagOfModels):
-                if segment is not None:
-                    for sub in demucs_model.models:
-                        sub.segment = segment
-            else:
-                if segment is not None: sub.segment = segment
-        except:
-            segment = None
-
-            if isinstance(demucs_model, BagOfModels):
-                if segment is not None:
-                    for sub in demucs_model.models:
-                        sub.segment = segment
-            else:
-                if segment is not None: sub.segment = segment
-
-    return demucs_model

main/library/uvr5_separator/demucs/demucs.py
DELETED
@@ -1,340 +0,0 @@
-import math
-import torch
-import julius
-
-import typing as tp
-
-from torch import nn
-
-from torch.nn import functional as F
-
-from .utils import center_trim
-from .states import capture_init
-
-
-def unfold(a, kernel_size, stride):
-    *shape, length = a.shape
-    n_frames = math.ceil(length / stride)
-
-    tgt_length = (n_frames - 1) * stride + kernel_size
-    a = F.pad(a, (0, tgt_length - length))
-    strides = list(a.stride())
-
-    assert strides[-1] == 1
-
-    strides = strides[:-1] + [stride, 1]
-
-    return a.as_strided([*shape, n_frames, kernel_size], strides)
-
-def rescale_conv(conv, reference):
-    scale = (conv.weight.std().detach() / reference) ** 0.5
-    conv.weight.data /= scale
-
-    if conv.bias is not None: conv.bias.data /= scale
-
-def rescale_module(module, reference):
-    for sub in module.modules():
-        if isinstance(sub, (nn.Conv1d, nn.ConvTranspose1d, nn.Conv2d, nn.ConvTranspose2d)): rescale_conv(sub, reference)
-
-class BLSTM(nn.Module):
-    def __init__(self, dim, layers=1, max_steps=None, skip=False):
-        super().__init__()
-        assert max_steps is None or max_steps % 4 == 0
-        self.max_steps = max_steps
-        self.lstm = nn.LSTM(bidirectional=True, num_layers=layers, hidden_size=dim, input_size=dim)
-        self.linear = nn.Linear(2 * dim, dim)
-        self.skip = skip
-
-    def forward(self, x):
-        B, C, T = x.shape
-        y = x
-        framed = False
-
-        if self.max_steps is not None and T > self.max_steps:
-            width = self.max_steps
-            stride = width // 2
-            frames = unfold(x, width, stride)
-            nframes = frames.shape[2]
-            framed = True
-            x = frames.permute(0, 2, 1, 3).reshape(-1, C, width)
-
-        x = x.permute(2, 0, 1)
-
-        x = self.lstm(x)[0]
-        x = self.linear(x)
-        x = x.permute(1, 2, 0)
-
-        if framed:
-            out = []
-            frames = x.reshape(B, -1, C, width)
-            limit = stride // 2
-
-            for k in range(nframes):
-                if k == 0: out.append(frames[:, k, :, :-limit])
-                elif k == nframes - 1: out.append(frames[:, k, :, limit:])
-                else: out.append(frames[:, k, :, limit:-limit])
-
-            out = torch.cat(out, -1)
-            out = out[..., :T]
-            x = out
-
-        if self.skip: x = x + y
-
-        return x
-
-class LayerScale(nn.Module):
-    def __init__(self, channels: int, init: float = 0):
-        super().__init__()
-        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
-        self.scale.data[:] = init
-
-    def forward(self, x):
-        return self.scale[:, None] * x
-
-class DConv(nn.Module):
-    def __init__(self, channels: int, compress: float = 4, depth: int = 2, init: float = 1e-4, norm=True, attn=False, heads=4, ndecay=4, lstm=False, gelu=True, kernel=3, dilate=True):
-        super().__init__()
-        assert kernel % 2 == 1
-        self.channels = channels
-        self.compress = compress
-        self.depth = abs(depth)
-        dilate = depth > 0
-
-        norm_fn: tp.Callable[[int], nn.Module]
-        norm_fn = lambda d: nn.Identity()
-
-        if norm: norm_fn = lambda d: nn.GroupNorm(1, d)
-
-        hidden = int(channels / compress)
-
-        act: tp.Type[nn.Module]
-        act = nn.GELU if gelu else nn.ReLU
-
-        self.layers = nn.ModuleList([])
-
-        for d in range(self.depth):
-            dilation = 2**d if dilate else 1
-            padding = dilation * (kernel // 2)
-
-            mods = [
-                nn.Conv1d(channels, hidden, kernel, dilation=dilation, padding=padding),
-                norm_fn(hidden),
-                act(),
-                nn.Conv1d(hidden, 2 * channels, 1),
-                norm_fn(2 * channels),
-                nn.GLU(1),
-                LayerScale(channels, init),
-            ]
-
-            if attn: mods.insert(3, LocalState(hidden, heads=heads, ndecay=ndecay))
-            if lstm: mods.insert(3, BLSTM(hidden, layers=2, max_steps=200, skip=True))
-
-            layer = nn.Sequential(*mods)
-            self.layers.append(layer)
-
-    def forward(self, x):
-        for layer in self.layers:
-            x = x + layer(x)
-
-        return x
-
-class LocalState(nn.Module):
-    def __init__(self, channels: int, heads: int = 4, nfreqs: int = 0, ndecay: int = 4):
-        super().__init__()
-
-        assert channels % heads == 0, (channels, heads)
-
-        self.heads = heads
-        self.nfreqs = nfreqs
-        self.ndecay = ndecay
-        self.content = nn.Conv1d(channels, channels, 1)
-        self.query = nn.Conv1d(channels, channels, 1)
-        self.key = nn.Conv1d(channels, channels, 1)
-
-        if nfreqs: self.query_freqs = nn.Conv1d(channels, heads * nfreqs, 1)
-
-        if ndecay:
-            self.query_decay = nn.Conv1d(channels, heads * ndecay, 1)
-            self.query_decay.weight.data *= 0.01
-
-            assert self.query_decay.bias is not None
-
-            self.query_decay.bias.data[:] = -2
-
-        self.proj = nn.Conv1d(channels + heads * nfreqs, channels, 1)
-
-    def forward(self, x):
-        B, C, T = x.shape
-        heads = self.heads
-        indexes = torch.arange(T, device=x.device, dtype=x.dtype)
-        delta = indexes[:, None] - indexes[None, :]
-
-        queries = self.query(x).view(B, heads, -1, T)
-        keys = self.key(x).view(B, heads, -1, T)
-
-        dots = torch.einsum("bhct,bhcs->bhts", keys, queries)
-        dots /= keys.shape[2] ** 0.5
-
-
-        if self.nfreqs:
-            periods = torch.arange(1, self.nfreqs + 1, device=x.device, dtype=x.dtype)
-            freq_kernel = torch.cos(2 * math.pi * delta / periods.view(-1, 1, 1))
-            freq_q = self.query_freqs(x).view(B, heads, -1, T) / self.nfreqs**0.5
-            dots += torch.einsum("fts,bhfs->bhts", freq_kernel, freq_q)
-
-        if self.ndecay:
-            decays = torch.arange(1, self.ndecay + 1, device=x.device, dtype=x.dtype)
-            decay_q = self.query_decay(x).view(B, heads, -1, T)
-            decay_q = torch.sigmoid(decay_q) / 2
-            decay_kernel = -decays.view(-1, 1, 1) * delta.abs() / self.ndecay**0.5
-            dots += torch.einsum("fts,bhfs->bhts", decay_kernel, decay_q)
-
-
-        dots.masked_fill_(torch.eye(T, device=dots.device, dtype=torch.bool), -100)
-        weights = torch.softmax(dots, dim=2)
-
-        content = self.content(x).view(B, heads, -1, T)
-        result = torch.einsum("bhts,bhct->bhcs", weights, content)
-
-        if self.nfreqs:
-            time_sig = torch.einsum("bhts,fts->bhfs", weights, freq_kernel)
-            result = torch.cat([result, time_sig], 2)
-
-        result = result.reshape(B, -1, T)
-        return x + self.proj(result)
-
-class Demucs(nn.Module):
-    @capture_init
-    def __init__(self, sources, audio_channels=2, channels=64, growth=2.0, depth=6, rewrite=True, lstm_layers=0, kernel_size=8, stride=4, context=1, gelu=True, glu=True, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=4, dconv_attn=4, dconv_lstm=4, dconv_init=1e-4, normalize=True, resample=True, rescale=0.1, samplerate=44100, segment=4 * 10):
-        super().__init__()
-        self.audio_channels = audio_channels
-        self.sources = sources
-        self.kernel_size = kernel_size
-        self.context = context
-        self.stride = stride
-        self.depth = depth
-        self.resample = resample
-        self.channels = channels
-        self.normalize = normalize
-        self.samplerate = samplerate
-        self.segment = segment
-
-        self.encoder = nn.ModuleList()
-        self.decoder = nn.ModuleList()
-        self.skip_scales = nn.ModuleList()
-
-        if glu:
-            activation = nn.GLU(dim=1)
-            ch_scale = 2
-        else:
-            activation = nn.ReLU()
-            ch_scale = 1
-
-
-        act2 = nn.GELU if gelu else nn.ReLU
-
-        in_channels = audio_channels
-        padding = 0
-
-
-        for index in range(depth):
-            norm_fn = lambda d: nn.Identity()
-
-            if index >= norm_starts: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)
-
-            encode = []
-            encode += [nn.Conv1d(in_channels, channels, kernel_size, stride), norm_fn(channels), act2()]
-
-            attn = index >= dconv_attn
-            lstm = index >= dconv_lstm
-
-            if dconv_mode & 1: encode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
-            if rewrite: encode += [nn.Conv1d(channels, ch_scale * channels, 1), norm_fn(ch_scale * channels), activation]
-
-            self.encoder.append(nn.Sequential(*encode))
-
-            decode = []
-
-            out_channels = in_channels if index > 0 else len(self.sources) * audio_channels
-
-            if rewrite: decode += [nn.Conv1d(channels, ch_scale * channels, 2 * context + 1, padding=context), norm_fn(ch_scale * channels), activation]
-            if dconv_mode & 2: decode += [DConv(channels, depth=dconv_depth, init=dconv_init, compress=dconv_comp, attn=attn, lstm=lstm)]
-
-            decode += [nn.ConvTranspose1d(channels, out_channels, kernel_size, stride, padding=padding)]
-
-            if index > 0: decode += [norm_fn(out_channels), act2()]
-
-            self.decoder.insert(0, nn.Sequential(*decode))
-            in_channels = channels
-            channels = int(growth * channels)
-
-
-        channels = in_channels
-
-        self.lstm = BLSTM(channels, lstm_layers) if lstm_layers else None
-
-
-        if rescale: rescale_module(self, reference=rescale)
-
-    def valid_length(self, length):
-        if self.resample: length *= 2
-
-        for _ in range(self.depth):
-            length = math.ceil((length - self.kernel_size) / self.stride) + 1
-            length = max(1, length)
-
-        for _ in range(self.depth):
-            length = (length - 1) * self.stride + self.kernel_size
-
-        if self.resample: length = math.ceil(length / 2)
-
-        return int(length)
-
-    def forward(self, mix):
-        x = mix
-        length = x.shape[-1]
-
-        if self.normalize:
-            mono = mix.mean(dim=1, keepdim=True)
-            mean = mono.mean(dim=-1, keepdim=True)
-            std = mono.std(dim=-1, keepdim=True)
-            x = (x - mean) / (1e-5 + std)
-        else:
-            mean = 0
-            std = 1
-
-        delta = self.valid_length(length) - length
-        x = F.pad(x, (delta // 2, delta - delta // 2))
-
-        if self.resample: x = julius.resample_frac(x, 1, 2)
-
-        saved = []
-
-        for encode in self.encoder:
-            x = encode(x)
-            saved.append(x)
-
-        if self.lstm: x = self.lstm(x)
-
-        for decode in self.decoder:
-            skip = saved.pop(-1)
-            skip = center_trim(skip, x)
-            x = decode(x + skip)
-
-        if self.resample: x = julius.resample_frac(x, 2, 1)
-
-        x = x * std + mean
-        x = center_trim(x, length)
-        x = x.view(x.size(0), len(self.sources), self.audio_channels, x.size(-1))
-
-        return x
-
-    def load_state_dict(self, state, strict=True):
-        for idx in range(self.depth):
-            for a in ["encoder", "decoder"]:
-                for b in ["bias", "weight"]:
-                    new = f"{a}.{idx}.3.{b}"
-                    old = f"{a}.{idx}.2.{b}"
-
-                    if old in state and new not in state: state[new] = state.pop(old)
-        super().load_state_dict(state, strict=strict)

main/library/uvr5_separator/demucs/hdemucs.py
DELETED
@@ -1,850 +0,0 @@
|
|
1 |
-
import math
|
2 |
-
import torch
|
3 |
-
|
4 |
-
import typing as tp
|
5 |
-
|
6 |
-
from torch import nn
|
7 |
-
from copy import deepcopy
|
8 |
-
from typing import Optional
|
9 |
-
|
10 |
-
from torch.nn import functional as F
|
11 |
-
|
12 |
-
from .states import capture_init
|
13 |
-
from .demucs import DConv, rescale_module
|
14 |
-
|
15 |
-
|
16 |
-
def spectro(x, n_fft=512, hop_length=None, pad=0):
|
17 |
-
*other, length = x.shape
|
18 |
-
x = x.reshape(-1, length)
|
19 |
-
device_type = x.device.type
|
20 |
-
is_other_gpu = not device_type in ["cuda", "cpu"]
|
21 |
-
|
22 |
-
if is_other_gpu: x = x.cpu()
|
23 |
-
|
24 |
-
z = torch.stft(x, n_fft * (1 + pad), hop_length or n_fft // 4, window=torch.hann_window(n_fft).to(x), win_length=n_fft, normalized=True, center=True, return_complex=True, pad_mode="reflect")
|
25 |
-
_, freqs, frame = z.shape
|
26 |
-
|
27 |
-
return z.view(*other, freqs, frame)
|
28 |
-
|
29 |
-
|
30 |
-
def ispectro(z, hop_length=None, length=None, pad=0):
|
31 |
-
*other, freqs, frames = z.shape
|
32 |
-
n_fft = 2 * freqs - 2
|
33 |
-
z = z.view(-1, freqs, frames)
|
34 |
-
win_length = n_fft // (1 + pad)
|
35 |
-
device_type = z.device.type
|
36 |
-
is_other_gpu = not device_type in ["cuda", "cpu"]
|
37 |
-
|
38 |
-
if is_other_gpu: z = z.cpu()
|
39 |
-
|
40 |
-
x = torch.istft(z, n_fft, hop_length, window=torch.hann_window(win_length).to(z.real), win_length=win_length, normalized=True, length=length, center=True)
|
41 |
-
_, length = x.shape
|
42 |
-
|
43 |
-
return x.view(*other, length)
|
44 |
-
|
45 |
-
|
46 |
-
def atan2(y, x):
|
47 |
-
pi = 2 * torch.asin(torch.tensor(1.0))
|
48 |
-
x += ((x == 0) & (y == 0)) * 1.0
|
49 |
-
|
50 |
-
out = torch.atan(y / x)
|
51 |
-
out += ((y >= 0) & (x < 0)) * pi
|
52 |
-
out -= ((y < 0) & (x < 0)) * pi
|
53 |
-
out *= 1 - ((y > 0) & (x == 0)) * 1.0
|
54 |
-
out += ((y > 0) & (x == 0)) * (pi / 2)
|
55 |
-
out *= 1 - ((y < 0) & (x == 0)) * 1.0
|
56 |
-
out += ((y < 0) & (x == 0)) * (-pi / 2)
|
57 |
-
|
58 |
-
return out
|
59 |
-
|
60 |
-
|
61 |
-
def _norm(x: torch.Tensor) -> torch.Tensor:
|
62 |
-
return torch.abs(x[..., 0]) ** 2 + torch.abs(x[..., 1]) ** 2
|
63 |
-
|
64 |
-
|
65 |
-
def _mul_add(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
|
66 |
-
target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])
|
67 |
-
|
68 |
-
if out is None or out.shape != target_shape: out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)
|
69 |
-
|
70 |
-
if out is a:
|
71 |
-
real_a = a[..., 0]
|
72 |
-
out[..., 0] = out[..., 0] + (real_a * b[..., 0] - a[..., 1] * b[..., 1])
|
73 |
-
out[..., 1] = out[..., 1] + (real_a * b[..., 1] + a[..., 1] * b[..., 0])
|
74 |
-
else:
|
75 |
-
out[..., 0] = out[..., 0] + (a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1])
|
76 |
-
out[..., 1] = out[..., 1] + (a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0])
|
77 |
-
|
78 |
-
return out
|
79 |
-
|
80 |
-
|
81 |
-
def _mul(a: torch.Tensor, b: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    target_shape = torch.Size([max(sa, sb) for (sa, sb) in zip(a.shape, b.shape)])

    if out is None or out.shape != target_shape: out = torch.zeros(target_shape, dtype=a.dtype, device=a.device)

    if out is a:
        real_a = a[..., 0]
        out[..., 0] = real_a * b[..., 0] - a[..., 1] * b[..., 1]
        out[..., 1] = real_a * b[..., 1] + a[..., 1] * b[..., 0]
    else:
        out[..., 0] = a[..., 0] * b[..., 0] - a[..., 1] * b[..., 1]
        out[..., 1] = a[..., 0] * b[..., 1] + a[..., 1] * b[..., 0]

    return out


def _inv(z: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    ez = _norm(z)

    if out is None or out.shape != z.shape: out = torch.zeros_like(z)

    out[..., 0] = z[..., 0] / ez
    out[..., 1] = -z[..., 1] / ez

    return out


def _conj(z, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    if out is None or out.shape != z.shape: out = torch.zeros_like(z)

    out[..., 0] = z[..., 0]
    out[..., 1] = -z[..., 1]

    return out


def _invert(M: torch.Tensor, out: Optional[torch.Tensor] = None) -> torch.Tensor:
    nb_channels = M.shape[-2]

    if out is None or out.shape != M.shape: out = torch.empty_like(M)

    if nb_channels == 1: out = _inv(M, out)
    elif nb_channels == 2:
        det = _mul(M[..., 0, 0, :], M[..., 1, 1, :])
        det = det - _mul(M[..., 0, 1, :], M[..., 1, 0, :])
        invDet = _inv(det)

        out[..., 0, 0, :] = _mul(invDet, M[..., 1, 1, :], out[..., 0, 0, :])
        out[..., 1, 0, :] = _mul(-invDet, M[..., 1, 0, :], out[..., 1, 0, :])
        out[..., 0, 1, :] = _mul(-invDet, M[..., 0, 1, :], out[..., 0, 1, :])
        out[..., 1, 1, :] = _mul(invDet, M[..., 0, 0, :], out[..., 1, 1, :])
    else: raise Exception("Torch == 2 Channels")

    return out


def expectation_maximization(y: torch.Tensor, x: torch.Tensor, iterations: int = 2, eps: float = 1e-10, batch_size: int = 200):
    (nb_frames, nb_bins, nb_channels) = x.shape[:-1]
    nb_sources = y.shape[-1]

    regularization = torch.cat((torch.eye(nb_channels, dtype=x.dtype, device=x.device)[..., None], torch.zeros((nb_channels, nb_channels, 1), dtype=x.dtype, device=x.device)), dim=2)
    regularization = torch.sqrt(torch.as_tensor(eps)) * (regularization[None, None, ...].expand((-1, nb_bins, -1, -1, -1)))

    R = [torch.zeros((nb_bins, nb_channels, nb_channels, 2), dtype=x.dtype, device=x.device) for j in range(nb_sources)]
    weight: torch.Tensor = torch.zeros((nb_bins,), dtype=x.dtype, device=x.device)

    v: torch.Tensor = torch.zeros((nb_frames, nb_bins, nb_sources), dtype=x.dtype, device=x.device)

    for _ in range(iterations):
        v = torch.mean(torch.abs(y[..., 0, :]) ** 2 + torch.abs(y[..., 1, :]) ** 2, dim=-2)

        for j in range(nb_sources):
            R[j] = torch.tensor(0.0, device=x.device)

            weight = torch.tensor(eps, device=x.device)
            pos: int = 0

            batch_size = batch_size if batch_size else nb_frames

            while pos < nb_frames:
                t = torch.arange(pos, min(nb_frames, pos + batch_size))
                pos = int(t[-1]) + 1

                R[j] = R[j] + torch.sum(_covariance(y[t, ..., j]), dim=0)
                weight = weight + torch.sum(v[t, ..., j], dim=0)

            R[j] = R[j] / weight[..., None, None, None]
            weight = torch.zeros_like(weight)

        if y.requires_grad: y = y.clone()

        pos = 0

        while pos < nb_frames:
            t = torch.arange(pos, min(nb_frames, pos + batch_size))
            pos = int(t[-1]) + 1

            y[t, ...] = torch.tensor(0.0, device=x.device, dtype=x.dtype)

            Cxx = regularization

            for j in range(nb_sources):
                Cxx = Cxx + (v[t, ..., j, None, None, None] * R[j][None, ...].clone())

            inv_Cxx = _invert(Cxx)

            for j in range(nb_sources):
                gain = torch.zeros_like(inv_Cxx)

                indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels), torch.arange(nb_channels))

                for index in indices:
                    gain[:, :, index[0], index[1], :] = _mul_add(R[j][None, :, index[0], index[2], :].clone(), inv_Cxx[:, :, index[2], index[1], :], gain[:, :, index[0], index[1], :])

                gain = gain * v[t, ..., None, None, None, j]

                for i in range(nb_channels):
                    y[t, ..., j] = _mul_add(gain[..., i, :], x[t, ..., i, None, :], y[t, ..., j])

    return y, v, R


def wiener(targets_spectrograms: torch.Tensor, mix_stft: torch.Tensor, iterations: int = 1, softmask: bool = False, residual: bool = False, scale_factor: float = 10.0, eps: float = 1e-10):
    if softmask: y = mix_stft[..., None] * (targets_spectrograms / (eps + torch.sum(targets_spectrograms, dim=-1, keepdim=True).to(mix_stft.dtype)))[..., None, :]
    else:
        angle = atan2(mix_stft[..., 1], mix_stft[..., 0])[..., None]
        nb_sources = targets_spectrograms.shape[-1]
        y = torch.zeros(mix_stft.shape + (nb_sources,), dtype=mix_stft.dtype, device=mix_stft.device)
        y[..., 0, :] = targets_spectrograms * torch.cos(angle)
        y[..., 1, :] = targets_spectrograms * torch.sin(angle)

    if residual: y = torch.cat([y, mix_stft[..., None] - y.sum(dim=-1, keepdim=True)], dim=-1)
    if iterations == 0: return y

    max_abs = torch.max(torch.as_tensor(1.0, dtype=mix_stft.dtype, device=mix_stft.device), torch.sqrt(_norm(mix_stft)).max() / scale_factor)

    mix_stft = mix_stft / max_abs
    y = y / max_abs

    y = expectation_maximization(y, mix_stft, iterations, eps=eps)[0]

    y = y * max_abs
    return y
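For reference, a minimal sketch of how this `wiener` helper is shaped to be called, mirroring the per-window call in `HDemucs._wiener` further down. This is illustrative only, not part of the deleted file; it assumes the truncated top of the module (with `_norm`, `atan2` and `_mul_add`) is in scope, and the tensor sizes are arbitrary:

import torch

nb_frames, nb_bins, nb_channels, nb_sources = 100, 513, 2, 4

# Mixture STFT as a real-valued "view_as_real" tensor: last dim is (re, im).
mix_stft = torch.randn(nb_frames, nb_bins, nb_channels, 2)
# Non-negative magnitude estimates, one slice per target source.
targets = torch.rand(nb_frames, nb_bins, nb_channels, nb_sources)

y = wiener(targets, mix_stft, iterations=1)
print(y.shape)  # torch.Size([100, 513, 2, 2, 4]) -> (..., re/im, nb_sources)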
def _covariance(y_j):
    (nb_frames, nb_bins, nb_channels) = y_j.shape[:-1]

    Cj = torch.zeros((nb_frames, nb_bins, nb_channels, nb_channels, 2), dtype=y_j.dtype, device=y_j.device)
    indices = torch.cartesian_prod(torch.arange(nb_channels), torch.arange(nb_channels))

    for index in indices:
        Cj[:, :, index[0], index[1], :] = _mul_add(y_j[:, :, index[0], :], _conj(y_j[:, :, index[1], :]), Cj[:, :, index[0], index[1], :])

    return Cj


def pad1d(x: torch.Tensor, paddings: tp.Tuple[int, int], mode: str = "constant", value: float = 0.0):
    x0 = x
    length = x.shape[-1]
    padding_left, padding_right = paddings

    if mode == "reflect":
        max_pad = max(padding_left, padding_right)

        if length <= max_pad:
            extra_pad = max_pad - length + 1
            extra_pad_right = min(padding_right, extra_pad)
            extra_pad_left = extra_pad - extra_pad_right
            paddings = (padding_left - extra_pad_left, padding_right - extra_pad_right)
            x = F.pad(x, (extra_pad_left, extra_pad_right))

    out = F.pad(x, paddings, mode, value)

    assert out.shape[-1] == length + padding_left + padding_right
    assert (out[..., padding_left : padding_left + length] == x0).all()

    return out
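A quick illustration of why `pad1d` exists rather than calling `F.pad` directly: PyTorch's reflect padding requires each pad to be smaller than the input length, so for very short inputs `pad1d` first zero-extends the signal until the reflection is valid. A small sketch (shapes arbitrary, not part of the deleted file):

import torch
import torch.nn.functional as F

x = torch.arange(4.0).view(1, 1, 4)   # length-4 signal

# F.pad(x, (6, 6), "reflect") would raise here, since pad (6) >= length (4).
out = pad1d(x, (6, 6), mode="reflect")
print(out.shape)                       # torch.Size([1, 1, 16])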
class ScaledEmbedding(nn.Module):
    def __init__(self, num_embeddings: int, embedding_dim: int, scale: float = 10.0, smooth=False):
        super().__init__()

        self.embedding = nn.Embedding(num_embeddings, embedding_dim)

        if smooth:
            weight = torch.cumsum(self.embedding.weight.data, dim=0)
            weight = weight / torch.arange(1, num_embeddings + 1).to(weight).sqrt()[:, None]

            self.embedding.weight.data[:] = weight

        self.embedding.weight.data /= scale
        self.scale = scale

    @property
    def weight(self):
        return self.embedding.weight * self.scale

    def forward(self, x):
        return self.embedding(x) * self.scale
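The trick in `ScaledEmbedding` is that the table is stored divided by `scale` and multiplied back on the way out, which effectively boosts the embedding's learning rate by `scale` relative to the rest of the network. A minimal sketch with arbitrary sizes (not part of the deleted file):

import torch

emb = ScaledEmbedding(num_embeddings=512, embedding_dim=48, scale=10.0, smooth=True)
frs = torch.arange(512)
print(emb(frs).shape)    # torch.Size([512, 48])
print(emb.weight.shape)  # rescaled view of the underlying table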
class HEncLayer(nn.Module):
    def __init__(self, chin, chout, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=0, dconv_kw={}, pad=True, rewrite=True):
        super().__init__()
        norm_fn = lambda d: nn.Identity()

        if norm: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)

        pad = kernel_size // 4 if pad else 0

        klass = nn.Conv1d
        self.freq = freq
        self.kernel_size = kernel_size
        self.stride = stride
        self.empty = empty
        self.norm = norm
        self.pad = pad

        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            pad = [pad, 0]
            klass = nn.Conv2d

        self.conv = klass(chin, chout, kernel_size, stride, pad)

        if self.empty: return

        self.norm1 = norm_fn(chout)
        self.rewrite = None

        if rewrite:
            self.rewrite = klass(chout, 2 * chout, 1 + 2 * context, 1, context)
            self.norm2 = norm_fn(2 * chout)

        self.dconv = None

        if dconv: self.dconv = DConv(chout, **dconv_kw)

    def forward(self, x, inject=None):
        if not self.freq and x.dim() == 4:
            B, C, Fr, T = x.shape
            x = x.view(B, -1, T)

        if not self.freq:
            le = x.shape[-1]

            if not le % self.stride == 0: x = F.pad(x, (0, self.stride - (le % self.stride)))

        y = self.conv(x)

        if self.empty: return y

        if inject is not None:
            assert inject.shape[-1] == y.shape[-1], (inject.shape, y.shape)

            if inject.dim() == 3 and y.dim() == 4: inject = inject[:, :, None]

            y = y + inject

        y = F.gelu(self.norm1(y))

        if self.dconv:
            if self.freq:
                B, C, Fr, T = y.shape
                y = y.permute(0, 2, 1, 3).reshape(-1, C, T)

            y = self.dconv(y)

            if self.freq: y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)

        if self.rewrite:
            z = self.norm2(self.rewrite(y))
            z = F.glu(z, dim=1)
        else: z = y

        return z


class MultiWrap(nn.Module):
    def __init__(self, layer, split_ratios):
        super().__init__()

        self.split_ratios = split_ratios
        self.layers = nn.ModuleList()
        self.conv = isinstance(layer, HEncLayer)

        assert not layer.norm
        assert layer.freq
        assert layer.pad

        if not self.conv: assert not layer.context_freq

        for _ in range(len(split_ratios) + 1):
            lay = deepcopy(layer)

            if self.conv: lay.conv.padding = (0, 0)
            else: lay.pad = False

            for m in lay.modules():
                if hasattr(m, "reset_parameters"): m.reset_parameters()

            self.layers.append(lay)

    def forward(self, x, skip=None, length=None):
        B, C, Fr, T = x.shape

        ratios = list(self.split_ratios) + [1]
        start = 0
        outs = []

        for ratio, layer in zip(ratios, self.layers):
            if self.conv:
                pad = layer.kernel_size // 4

                if ratio == 1:
                    limit = Fr
                    frames = -1
                else:
                    limit = int(round(Fr * ratio))
                    le = limit - start

                    if start == 0: le += pad

                    frames = round((le - layer.kernel_size) / layer.stride + 1)
                    limit = start + (frames - 1) * layer.stride + layer.kernel_size

                    if start == 0: limit -= pad

                assert limit - start > 0, (limit, start)
                assert limit <= Fr, (limit, Fr)

                y = x[:, :, start:limit, :]

                if start == 0: y = F.pad(y, (0, 0, pad, 0))
                if ratio == 1: y = F.pad(y, (0, 0, 0, pad))

                outs.append(layer(y))
                start = limit - layer.kernel_size + layer.stride
            else:
                limit = Fr if ratio == 1 else int(round(Fr * ratio))

                last = layer.last
                layer.last = True

                y = x[:, :, start:limit]
                s = skip[:, :, start:limit]
                out, _ = layer(y, s, None)

                if outs:
                    outs[-1][:, :, -layer.stride :] += out[:, :, : layer.stride] - layer.conv_tr.bias.view(1, -1, 1, 1)
                    out = out[:, :, layer.stride :]

                if ratio == 1: out = out[:, :, : -layer.stride // 2, :]
                if start == 0: out = out[:, :, layer.stride // 2 :, :]

                outs.append(out)
                layer.last = last
                start = limit

        out = torch.cat(outs, dim=2)

        if not self.conv and not last: out = F.gelu(out)

        if self.conv: return out
        else: return out, None


class HDecLayer(nn.Module):
    def __init__(self, chin, chout, last=False, kernel_size=8, stride=4, norm_groups=1, empty=False, freq=True, dconv=True, norm=True, context=1, dconv_kw={}, pad=True, context_freq=True, rewrite=True):
        super().__init__()
        norm_fn = lambda d: nn.Identity()

        if norm: norm_fn = lambda d: nn.GroupNorm(norm_groups, d)

        pad = kernel_size // 4 if pad else 0

        self.pad = pad
        self.last = last
        self.freq = freq
        self.chin = chin
        self.empty = empty
        self.stride = stride
        self.kernel_size = kernel_size
        self.norm = norm
        self.context_freq = context_freq
        klass = nn.Conv1d
        klass_tr = nn.ConvTranspose1d

        if freq:
            kernel_size = [kernel_size, 1]
            stride = [stride, 1]
            klass = nn.Conv2d
            klass_tr = nn.ConvTranspose2d

        self.conv_tr = klass_tr(chin, chout, kernel_size, stride)
        self.norm2 = norm_fn(chout)

        if self.empty: return

        self.rewrite = None

        if rewrite:
            if context_freq: self.rewrite = klass(chin, 2 * chin, 1 + 2 * context, 1, context)
            else: self.rewrite = klass(chin, 2 * chin, [1, 1 + 2 * context], 1, [0, context])

            self.norm1 = norm_fn(2 * chin)

        self.dconv = None

        if dconv: self.dconv = DConv(chin, **dconv_kw)

    def forward(self, x, skip, length):
        if self.freq and x.dim() == 3:
            B, C, T = x.shape
            x = x.view(B, self.chin, -1, T)

        if not self.empty:
            x = x + skip

            y = F.glu(self.norm1(self.rewrite(x)), dim=1) if self.rewrite else x

            if self.dconv:
                if self.freq:
                    B, C, Fr, T = y.shape
                    y = y.permute(0, 2, 1, 3).reshape(-1, C, T)

                y = self.dconv(y)

                if self.freq: y = y.view(B, Fr, C, T).permute(0, 2, 1, 3)
        else:
            y = x
            assert skip is None

        z = self.norm2(self.conv_tr(y))

        if self.freq:
            if self.pad: z = z[..., self.pad : -self.pad, :]
        else:
            z = z[..., self.pad : self.pad + length]
            assert z.shape[-1] == length, (z.shape[-1], length)

        if not self.last: z = F.gelu(z)

        return z, y


class HDemucs(nn.Module):
    @capture_init
    def __init__(self, sources, audio_channels=2, channels=48, channels_time=None, growth=2, nfft=4096, wiener_iters=0, end_iters=0, wiener_residual=False, cac=True, depth=6, rewrite=True, hybrid=True, hybrid_old=False, multi_freqs=None, multi_freqs_depth=2, freq_emb=0.2, emb_scale=10, emb_smooth=True, kernel_size=8, time_stride=2, stride=4, context=1, context_enc=0, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=4, dconv_attn=4, dconv_lstm=4, dconv_init=1e-4, rescale=0.1, samplerate=44100, segment=4 * 10):
        super().__init__()

        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment
        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None
        self.hybrid = hybrid
        self.hybrid_old = hybrid_old

        if hybrid_old: assert hybrid
        if hybrid: assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()

        if hybrid:
            self.tencoder = nn.ModuleList()
            self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin

        if self.cac: chin_z *= 2

        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            lstm = index >= dconv_lstm
            attn = index >= dconv_attn
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size

            if not freq:
                assert freqs == 1

                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False

            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {"lstm": lstm, "attn": attn, "depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
            }

            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)

            multi = False

            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)

            if hybrid and freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi: enc = MultiWrap(enc, multi_freqs)

            self.encoder.append(enc)

            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin

                if self.cac: chin_z *= 2

            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)

            if multi: dec = MultiWrap(dec, multi_freqs)

            if hybrid and freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)

            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z

            chout = int(growth * chout)
            chout_z = int(growth * chout_z)

            if freq:
                if freqs <= kernel_size: freqs = 1
                else: freqs //= stride

            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale: rescale_module(self, reference=rescale)

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft

        if self.hybrid:
            assert hl == nfft // 4

            le = int(math.ceil(x.shape[-1] / hl))
            pad = hl // 2 * 3

            x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect") if not self.hybrid_old else pad1d(x, (pad, pad + le * hl - x.shape[-1]))

        z = spectro(x, nfft, hl)[..., :-1, :]

        if self.hybrid:
            assert z.shape[-1] == le + 4, (z.shape, x.shape, le)
            z = z[..., 2 : 2 + le]

        return z

    def _ispec(self, z, length=None, scale=0):
        hl = self.hop_length // (4**scale)
        z = F.pad(z, (0, 0, 0, 1))

        if self.hybrid:
            z = F.pad(z, (2, 2))
            pad = hl // 2 * 3
            le = hl * int(math.ceil(length / hl)) + 2 * pad if not self.hybrid_old else hl * int(math.ceil(length / hl))

            x = ispectro(z, hl, length=le)
            x = x[..., pad : pad + length] if not self.hybrid_old else x[..., :length]
        else: x = ispectro(z, hl, length)

        return x

    def _magnitude(self, z):
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else: m = z.abs()

        return m

    def _mask(self, z, m):
        niters = self.wiener_iters

        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out

        if self.training: niters = self.end_iters

        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else: return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []

        for sample in range(B):
            pos = 0
            out = []

            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
                out.append(z_out.transpose(-1, -2))

            outs.append(torch.cat(out, dim=0))

        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()

        if residual: out = out[:, :-1]

        assert list(out.shape) == [B, S, C, Fq, T]
        return out.to(init)

    def forward(self, mix):
        x = mix
        length = x.shape[-1]

        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag

        B, C, Fq, T = x.shape

        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)

        if self.hybrid:
            xt = mix
            meant = xt.mean(dim=(1, 2), keepdim=True)
            stdt = xt.std(dim=(1, 2), keepdim=True)
            xt = (xt - meant) / (1e-5 + stdt)

        saved = []
        saved_t = []
        lengths = []
        lengths_t = []

        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None

            if self.hybrid and idx < len(self.tencoder):
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)

                if not tenc.empty: saved_t.append(xt)
                else: inject = xt

            x = encode(x, inject)

            if idx == 0 and self.freq_emb is not None:
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        x = torch.zeros_like(x)

        if self.hybrid: xt = torch.zeros_like(x)

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))

            if self.hybrid: offset = self.depth - len(self.tdecoder)

            if self.hybrid and idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)

                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape

                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)

        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        device_type = x.device.type
        device_load = f"{device_type}:{x.device.index}" if not device_type == "mps" else device_type
        x_is_other_gpu = not device_type in ["cuda", "cpu"]

        if x_is_other_gpu: x = x.cpu()

        zout = self._mask(z, x)
        x = self._ispec(zout, length)

        if x_is_other_gpu: x = x.to(device_load)

        if self.hybrid:
            xt = xt.view(B, S, -1, length)
            xt = xt * stdt[:, None] + meant[:, None]
            x = xt + x

        return x
main/library/uvr5_separator/demucs/htdemucs.py
DELETED
@@ -1,690 +0,0 @@
import os
import sys
import math
import torch
import random

import numpy as np
import typing as tp

from torch import nn
from einops import rearrange
from fractions import Fraction
from torch.nn import functional as F


now_dir = os.getcwd()
sys.path.append(now_dir)

from .states import capture_init
from .demucs import rescale_module
from main.configs.config import Config
from .hdemucs import pad1d, spectro, ispectro, wiener, ScaledEmbedding, HEncLayer, MultiWrap, HDecLayer

translations = Config().translations


def create_sin_embedding(length: int, dim: int, shift: int = 0, device="cpu", max_period=10000):
    assert dim % 2 == 0

    pos = shift + torch.arange(length, device=device).view(-1, 1, 1)
    half_dim = dim // 2

    adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (half_dim - 1)))

    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1)

def create_2d_sin_embedding(d_model, height, width, device="cpu", max_period=10000):
    if d_model % 4 != 0: raise ValueError(translations["dims"].format(dims=d_model))

    pe = torch.zeros(d_model, height, width)

    d_model = int(d_model / 2)

    div_term = torch.exp(torch.arange(0.0, d_model, 2) * -(math.log(max_period) / d_model))

    pos_w = torch.arange(0.0, width).unsqueeze(1)
    pos_h = torch.arange(0.0, height).unsqueeze(1)

    pe[0:d_model:2, :, :] = torch.sin(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[1:d_model:2, :, :] = torch.cos(pos_w * div_term).transpose(0, 1).unsqueeze(1).repeat(1, height, 1)
    pe[d_model::2, :, :] = torch.sin(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)
    pe[d_model + 1 :: 2, :, :] = torch.cos(pos_h * div_term).transpose(0, 1).unsqueeze(2).repeat(1, 1, width)

    return pe[None, :].to(device)
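For orientation, the two sinusoidal helpers above return differently laid-out tensors: the 1D variant is (length, 1, dim) and the 2D variant is (1, d_model, height, width), matching how they are added to the time and spectral branches in `CrossTransformerEncoder.forward` further down. A small sketch (dimensions arbitrary, not part of the deleted file; `create_2d_sin_embedding` requires d_model divisible by 4):

import torch

pos = create_sin_embedding(length=10, dim=64)
print(pos.shape)   # torch.Size([10, 1, 64])

pe2d = create_2d_sin_embedding(64, height=32, width=10)
print(pe2d.shape)  # torch.Size([1, 64, 32, 10])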
def create_sin_embedding_cape(length: int, dim: int, batch_size: int, mean_normalize: bool, augment: bool, max_global_shift: float = 0.0, max_local_shift: float = 0.0, max_scale: float = 1.0, device: str = "cpu", max_period: float = 10000.0):
    assert dim % 2 == 0

    pos = 1.0 * torch.arange(length).view(-1, 1, 1)
    pos = pos.repeat(1, batch_size, 1)

    if mean_normalize: pos -= torch.nanmean(pos, dim=0, keepdim=True)

    if augment:
        delta = np.random.uniform(-max_global_shift, +max_global_shift, size=[1, batch_size, 1])
        delta_local = np.random.uniform(-max_local_shift, +max_local_shift, size=[length, batch_size, 1])

        log_lambdas = np.random.uniform(-np.log(max_scale), +np.log(max_scale), size=[1, batch_size, 1])
        pos = (pos + delta + delta_local) * np.exp(log_lambdas)

    pos = pos.to(device)

    half_dim = dim // 2
    adim = torch.arange(dim // 2, device=device).view(1, 1, -1)
    phase = pos / (max_period ** (adim / (half_dim - 1)))

    return torch.cat([torch.cos(phase), torch.sin(phase)], dim=-1).float()


class MyGroupNorm(nn.GroupNorm):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    def forward(self, x):
        x = x.transpose(1, 2)
        return super().forward(x).transpose(1, 2)

class LayerScale(nn.Module):
    def __init__(self, channels: int, init: float = 0, channel_last=False):
        super().__init__()
        self.channel_last = channel_last
        self.scale = nn.Parameter(torch.zeros(channels, requires_grad=True))
        self.scale.data[:] = init

    def forward(self, x):
        if self.channel_last: return self.scale * x
        else: return self.scale[:, None] * x
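`LayerScale` follows the CaiT-style idea of starting each residual branch near zero: the per-channel scale is initialized to a small `init`, so early in training every transformer block stays close to the identity. A minimal sketch (shapes arbitrary, not part of the deleted file):

import torch

ls = LayerScale(channels=8, init=1e-4, channel_last=True)
x = torch.randn(4, 16, 8)  # (batch, seq, channels)
print(ls(x).shape)         # torch.Size([4, 16, 8]); output ~ 1e-4 * x at init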
class MyTransformerEncoderLayer(nn.TransformerEncoderLayer):
    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation=F.relu, group_norm=0, norm_first=False, norm_out=False, layer_norm_eps=1e-5, layer_scale=False, init_values=1e-4, device=None, dtype=None, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, auto_sparsity=False, sparsity=0.95, batch_first=False):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout, activation=activation, layer_norm_eps=layer_norm_eps, batch_first=batch_first, norm_first=norm_first, device=device, dtype=dtype)

        self.auto_sparsity = auto_sparsity

        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None

        if self.norm_first & norm_out: self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()

    def forward(self, src, src_mask=None, src_key_padding_mask=None):
        x = src
        T, B, C = x.shape

        if self.norm_first:
            x = x + self.gamma_1(self._sa_block(self.norm1(x), src_mask, src_key_padding_mask))
            x = x + self.gamma_2(self._ff_block(self.norm2(x)))

            if self.norm_out: x = self.norm_out(x)
        else:
            x = self.norm1(x + self.gamma_1(self._sa_block(x, src_mask, src_key_padding_mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

class CrossTransformerEncoder(nn.Module):
    def __init__(self, dim: int, emb: str = "sin", hidden_scale: float = 4.0, num_heads: int = 8, num_layers: int = 6, cross_first: bool = False, dropout: float = 0.0, max_positions: int = 1000, norm_in: bool = True, norm_in_group: bool = False, group_norm: int = False, norm_first: bool = False, norm_out: bool = False, max_period: float = 10000.0, weight_decay: float = 0.0, lr: tp.Optional[float] = None, layer_scale: bool = False, gelu: bool = True, sin_random_shift: int = 0, weight_pos_embed: float = 1.0, cape_mean_normalize: bool = True, cape_augment: bool = True, cape_glob_loc_scale: list = [5000.0, 1.0, 1.4], sparse_self_attn: bool = False, sparse_cross_attn: bool = False, mask_type: str = "diag", mask_random_seed: int = 42, sparse_attn_window: int = 500, global_window: int = 50, auto_sparsity: bool = False, sparsity: float = 0.95):
        super().__init__()
        assert dim % num_heads == 0

        hidden_dim = int(dim * hidden_scale)

        self.num_layers = num_layers
        self.classic_parity = 1 if cross_first else 0
        self.emb = emb
        self.max_period = max_period
        self.weight_decay = weight_decay
        self.weight_pos_embed = weight_pos_embed
        self.sin_random_shift = sin_random_shift

        if emb == "cape":
            self.cape_mean_normalize = cape_mean_normalize
            self.cape_augment = cape_augment
            self.cape_glob_loc_scale = cape_glob_loc_scale

        if emb == "scaled": self.position_embeddings = ScaledEmbedding(max_positions, dim, scale=0.2)

        self.lr = lr

        activation: tp.Any = F.gelu if gelu else F.relu

        self.norm_in: nn.Module
        self.norm_in_t: nn.Module

        if norm_in:
            self.norm_in = nn.LayerNorm(dim)
            self.norm_in_t = nn.LayerNorm(dim)
        elif norm_in_group:
            self.norm_in = MyGroupNorm(int(norm_in_group), dim)
            self.norm_in_t = MyGroupNorm(int(norm_in_group), dim)
        else:
            self.norm_in = nn.Identity()
            self.norm_in_t = nn.Identity()

        self.layers = nn.ModuleList()
        self.layers_t = nn.ModuleList()

        kwargs_common = {
            "d_model": dim,
            "nhead": num_heads,
            "dim_feedforward": hidden_dim,
            "dropout": dropout,
            "activation": activation,
            "group_norm": group_norm,
            "norm_first": norm_first,
            "norm_out": norm_out,
            "layer_scale": layer_scale,
            "mask_type": mask_type,
            "mask_random_seed": mask_random_seed,
            "sparse_attn_window": sparse_attn_window,
            "global_window": global_window,
            "sparsity": sparsity,
            "auto_sparsity": auto_sparsity,
            "batch_first": True,
        }

        kwargs_classic_encoder = dict(kwargs_common)
        kwargs_classic_encoder.update({"sparse": sparse_self_attn})
        kwargs_cross_encoder = dict(kwargs_common)
        kwargs_cross_encoder.update({"sparse": sparse_cross_attn})

        for idx in range(num_layers):
            if idx % 2 == self.classic_parity:
                self.layers.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
                self.layers_t.append(MyTransformerEncoderLayer(**kwargs_classic_encoder))
            else:
                self.layers.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))
                self.layers_t.append(CrossTransformerEncoderLayer(**kwargs_cross_encoder))

    def forward(self, x, xt):
        B, C, Fr, T1 = x.shape

        pos_emb_2d = create_2d_sin_embedding(C, Fr, T1, x.device, self.max_period)
        pos_emb_2d = rearrange(pos_emb_2d, "b c fr t1 -> b (t1 fr) c")

        x = rearrange(x, "b c fr t1 -> b (t1 fr) c")
        x = self.norm_in(x)
        x = x + self.weight_pos_embed * pos_emb_2d

        B, C, T2 = xt.shape
        xt = rearrange(xt, "b c t2 -> b t2 c")

        pos_emb = self._get_pos_embedding(T2, B, C, x.device)
        pos_emb = rearrange(pos_emb, "t2 b c -> b t2 c")

        xt = self.norm_in_t(xt)
        xt = xt + self.weight_pos_embed * pos_emb

        for idx in range(self.num_layers):
            if idx % 2 == self.classic_parity:
                x = self.layers[idx](x)
                xt = self.layers_t[idx](xt)
            else:
                old_x = x
                x = self.layers[idx](x, xt)
                xt = self.layers_t[idx](xt, old_x)

        x = rearrange(x, "b (t1 fr) c -> b c fr t1", t1=T1)
        xt = rearrange(xt, "b t2 c -> b c t2")

        return x, xt

    def _get_pos_embedding(self, T, B, C, device):
        if self.emb == "sin":
            shift = random.randrange(self.sin_random_shift + 1)
            pos_emb = create_sin_embedding(T, C, shift=shift, device=device, max_period=self.max_period)
        elif self.emb == "cape":
            if self.training: pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=self.cape_augment, max_global_shift=self.cape_glob_loc_scale[0], max_local_shift=self.cape_glob_loc_scale[1], max_scale=self.cape_glob_loc_scale[2])
            else: pos_emb = create_sin_embedding_cape(T, C, B, device=device, max_period=self.max_period, mean_normalize=self.cape_mean_normalize, augment=False)

        elif self.emb == "scaled":
            pos = torch.arange(T, device=device)
            pos_emb = self.position_embeddings(pos)[:, None]

        return pos_emb

    def make_optim_group(self):
        group = {"params": list(self.parameters()), "weight_decay": self.weight_decay}
        if self.lr is not None: group["lr"] = self.lr

        return group


class CrossTransformerEncoderLayer(nn.Module):
    def __init__(self, d_model: int, nhead: int, dim_feedforward: int = 2048, dropout: float = 0.1, activation=F.relu, layer_norm_eps: float = 1e-5, layer_scale: bool = False, init_values: float = 1e-4, norm_first: bool = False, group_norm: bool = False, norm_out: bool = False, sparse=False, mask_type="diag", mask_random_seed=42, sparse_attn_window=500, global_window=50, sparsity=0.95, auto_sparsity=None, device=None, dtype=None, batch_first=False):
        factory_kwargs = {"device": device, "dtype": dtype}
        super().__init__()

        self.auto_sparsity = auto_sparsity

        self.cross_attn: nn.Module
        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=batch_first)

        self.linear1 = nn.Linear(d_model, dim_feedforward, **factory_kwargs)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model, **factory_kwargs)

        self.norm_first = norm_first

        self.norm1: nn.Module
        self.norm2: nn.Module
        self.norm3: nn.Module

        if group_norm:
            self.norm1 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = MyGroupNorm(int(group_norm), d_model, eps=layer_norm_eps, **factory_kwargs)
        else:
            self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)
            self.norm3 = nn.LayerNorm(d_model, eps=layer_norm_eps, **factory_kwargs)

        self.norm_out = None
        if self.norm_first & norm_out:
            self.norm_out = MyGroupNorm(num_groups=int(norm_out), num_channels=d_model)

        self.gamma_1 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()
        self.gamma_2 = LayerScale(d_model, init_values, True) if layer_scale else nn.Identity()

        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        if isinstance(activation, str): self.activation = self._get_activation_fn(activation)
        else: self.activation = activation

    def forward(self, q, k, mask=None):
        if self.norm_first:
            x = q + self.gamma_1(self._ca_block(self.norm1(q), self.norm2(k), mask))
            x = x + self.gamma_2(self._ff_block(self.norm3(x)))

            if self.norm_out: x = self.norm_out(x)
        else:
            x = self.norm1(q + self.gamma_1(self._ca_block(q, k, mask)))
            x = self.norm2(x + self.gamma_2(self._ff_block(x)))

        return x

    def _ca_block(self, q, k, attn_mask=None):
        x = self.cross_attn(q, k, k, attn_mask=attn_mask, need_weights=False)[0]
        return self.dropout1(x)

    def _ff_block(self, x):
        x = self.linear2(self.dropout(self.activation(self.linear1(x))))
        return self.dropout2(x)

    def _get_activation_fn(self, activation):
        if activation == "relu": return F.relu
        elif activation == "gelu": return F.gelu

        raise RuntimeError(translations["activation"].format(activation=activation))


class HTDemucs(nn.Module):
    @capture_init
    def __init__(self, sources, audio_channels=2, channels=48, channels_time=None, growth=2, nfft=4096, wiener_iters=0, end_iters=0, wiener_residual=False, cac=True, depth=4, rewrite=True, multi_freqs=None, multi_freqs_depth=3, freq_emb=0.2, emb_scale=10, emb_smooth=True, kernel_size=8, time_stride=2, stride=4, context=1, context_enc=0, norm_starts=4, norm_groups=4, dconv_mode=1, dconv_depth=2, dconv_comp=8, dconv_init=1e-3, bottom_channels=0, t_layers=5, t_emb="sin", t_hidden_scale=4.0, t_heads=8, t_dropout=0.0, t_max_positions=10000, t_norm_in=True, t_norm_in_group=False, t_group_norm=False, t_norm_first=True, t_norm_out=True, t_max_period=10000.0, t_weight_decay=0.0, t_lr=None, t_layer_scale=True, t_gelu=True, t_weight_pos_embed=1.0, t_sin_random_shift=0, t_cape_mean_normalize=True, t_cape_augment=True, t_cape_glob_loc_scale=[5000.0, 1.0, 1.4], t_sparse_self_attn=False, t_sparse_cross_attn=False, t_mask_type="diag", t_mask_random_seed=42, t_sparse_attn_window=500, t_global_window=100, t_sparsity=0.95, t_auto_sparsity=False, t_cross_first=False, rescale=0.1, samplerate=44100, segment=4 * 10, use_train_segment=True):
        super().__init__()

        self.cac = cac
        self.wiener_residual = wiener_residual
        self.audio_channels = audio_channels
        self.sources = sources
        self.kernel_size = kernel_size
        self.context = context
        self.stride = stride
        self.depth = depth
        self.bottom_channels = bottom_channels
        self.channels = channels
        self.samplerate = samplerate
        self.segment = segment
        self.use_train_segment = use_train_segment
        self.nfft = nfft
        self.hop_length = nfft // 4
        self.wiener_iters = wiener_iters
        self.end_iters = end_iters
        self.freq_emb = None

        assert wiener_iters == end_iters

        self.encoder = nn.ModuleList()
        self.decoder = nn.ModuleList()
        self.tencoder = nn.ModuleList()
        self.tdecoder = nn.ModuleList()

        chin = audio_channels
        chin_z = chin

        if self.cac: chin_z *= 2

        chout = channels_time or channels
        chout_z = channels
        freqs = nfft // 2

        for index in range(depth):
            norm = index >= norm_starts
            freq = freqs > 1
            stri = stride
            ker = kernel_size

            if not freq:
                assert freqs == 1

                ker = time_stride * 2
                stri = time_stride

            pad = True
            last_freq = False

            if freq and freqs <= kernel_size:
                ker = freqs
                pad = False
                last_freq = True

            kw = {
                "kernel_size": ker,
                "stride": stri,
                "freq": freq,
                "pad": pad,
                "norm": norm,
                "rewrite": rewrite,
                "norm_groups": norm_groups,
                "dconv_kw": {"depth": dconv_depth, "compress": dconv_comp, "init": dconv_init, "gelu": True},
            }

            kwt = dict(kw)
            kwt["freq"] = 0
            kwt["kernel_size"] = kernel_size
            kwt["stride"] = stride
            kwt["pad"] = True
            kw_dec = dict(kw)

            multi = False

            if multi_freqs and index < multi_freqs_depth:
                multi = True
                kw_dec["context_freq"] = False

            if last_freq:
                chout_z = max(chout, chout_z)
                chout = chout_z

            enc = HEncLayer(chin_z, chout_z, dconv=dconv_mode & 1, context=context_enc, **kw)

            if freq:
                tenc = HEncLayer(chin, chout, dconv=dconv_mode & 1, context=context_enc, empty=last_freq, **kwt)
                self.tencoder.append(tenc)

            if multi: enc = MultiWrap(enc, multi_freqs)

            self.encoder.append(enc)

            if index == 0:
                chin = self.audio_channels * len(self.sources)
                chin_z = chin

                if self.cac: chin_z *= 2

            dec = HDecLayer(chout_z, chin_z, dconv=dconv_mode & 2, last=index == 0, context=context, **kw_dec)

            if multi:
                dec = MultiWrap(dec, multi_freqs)

            if freq:
                tdec = HDecLayer(chout, chin, dconv=dconv_mode & 2, empty=last_freq, last=index == 0, context=context, **kwt)
                self.tdecoder.insert(0, tdec)

            self.decoder.insert(0, dec)

            chin = chout
            chin_z = chout_z

            chout = int(growth * chout)
            chout_z = int(growth * chout_z)

            if freq:
                if freqs <= kernel_size: freqs = 1
                else: freqs //= stride

            if index == 0 and freq_emb:
                self.freq_emb = ScaledEmbedding(freqs, chin_z, smooth=emb_smooth, scale=emb_scale)
                self.freq_emb_scale = freq_emb

        if rescale: rescale_module(self, reference=rescale)

        transformer_channels = channels * growth ** (depth - 1)

        if bottom_channels:
            self.channel_upsampler = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler = nn.Conv1d(bottom_channels, transformer_channels, 1)
            self.channel_upsampler_t = nn.Conv1d(transformer_channels, bottom_channels, 1)
            self.channel_downsampler_t = nn.Conv1d(bottom_channels, transformer_channels, 1)

            transformer_channels = bottom_channels

        if t_layers > 0: self.crosstransformer = CrossTransformerEncoder(dim=transformer_channels, emb=t_emb, hidden_scale=t_hidden_scale, num_heads=t_heads, num_layers=t_layers, cross_first=t_cross_first, dropout=t_dropout, max_positions=t_max_positions, norm_in=t_norm_in, norm_in_group=t_norm_in_group, group_norm=t_group_norm, norm_first=t_norm_first, norm_out=t_norm_out, max_period=t_max_period, weight_decay=t_weight_decay, lr=t_lr, layer_scale=t_layer_scale, gelu=t_gelu, sin_random_shift=t_sin_random_shift, weight_pos_embed=t_weight_pos_embed, cape_mean_normalize=t_cape_mean_normalize, cape_augment=t_cape_augment, cape_glob_loc_scale=t_cape_glob_loc_scale, sparse_self_attn=t_sparse_self_attn, sparse_cross_attn=t_sparse_cross_attn, mask_type=t_mask_type, mask_random_seed=t_mask_random_seed, sparse_attn_window=t_sparse_attn_window, global_window=t_global_window, sparsity=t_sparsity, auto_sparsity=t_auto_sparsity)
        else: self.crosstransformer = None

    def _spec(self, x):
        hl = self.hop_length
        nfft = self.nfft

        assert hl == nfft // 4

        le = int(math.ceil(x.shape[-1] / hl))
        pad = hl // 2 * 3

        x = pad1d(x, (pad, pad + le * hl - x.shape[-1]), mode="reflect")

        z = spectro(x, nfft, hl)[..., :-1, :]

        assert z.shape[-1] == le + 4, (z.shape, x.shape, le)

        z = z[..., 2 : 2 + le]

        return z

    def _ispec(self, z, length=None, scale=0):
        hl = self.hop_length // (4**scale)
        z = F.pad(z, (0, 0, 0, 1))
        z = F.pad(z, (2, 2))

        pad = hl // 2 * 3
        le = hl * int(math.ceil(length / hl)) + 2 * pad

        x = ispectro(z, hl, length=le)
        x = x[..., pad : pad + length]

        return x

    def _magnitude(self, z):
        if self.cac:
            B, C, Fr, T = z.shape
            m = torch.view_as_real(z).permute(0, 1, 4, 2, 3)
            m = m.reshape(B, C * 2, Fr, T)
        else: m = z.abs()

        return m

    def _mask(self, z, m):
        niters = self.wiener_iters
        if self.cac:
            B, S, C, Fr, T = m.shape
            out = m.view(B, S, -1, 2, Fr, T).permute(0, 1, 2, 4, 5, 3)
            out = torch.view_as_complex(out.contiguous())
            return out

        if self.training: niters = self.end_iters

        if niters < 0:
            z = z[:, None]
            return z / (1e-8 + z.abs()) * m
        else: return self._wiener(m, z, niters)

    def _wiener(self, mag_out, mix_stft, niters):
        init = mix_stft.dtype
        wiener_win_len = 300
        residual = self.wiener_residual

        B, S, C, Fq, T = mag_out.shape
        mag_out = mag_out.permute(0, 4, 3, 2, 1)
        mix_stft = torch.view_as_real(mix_stft.permute(0, 3, 2, 1))

        outs = []

        for sample in range(B):
            pos = 0
            out = []

            for pos in range(0, T, wiener_win_len):
                frame = slice(pos, pos + wiener_win_len)
                z_out = wiener(mag_out[sample, frame], mix_stft[sample, frame], niters, residual=residual)
                out.append(z_out.transpose(-1, -2))

            outs.append(torch.cat(out, dim=0))

        out = torch.view_as_complex(torch.stack(outs, 0))
        out = out.permute(0, 4, 3, 2, 1).contiguous()

        if residual: out = out[:, :-1]

        assert list(out.shape) == [B, S, C, Fq, T]

        return out.to(init)

    def valid_length(self, length: int):
        if not self.use_train_segment: return length

        training_length = int(self.segment * self.samplerate)
        if training_length < length: raise ValueError(translations["length_or_training_length"].format(length=length, training_length=training_length))

        return training_length

    def forward(self, mix):
        length = mix.shape[-1]
        length_pre_pad = None

        if self.use_train_segment:
            if self.training: self.segment = Fraction(mix.shape[-1], self.samplerate)
            else:
                training_length = int(self.segment * self.samplerate)

                if mix.shape[-1] < training_length:
                    length_pre_pad = mix.shape[-1]
                    mix = F.pad(mix, (0, training_length - length_pre_pad))

        z = self._spec(mix)
        mag = self._magnitude(z).to(mix.device)
        x = mag

        B, C, Fq, T = x.shape

        mean = x.mean(dim=(1, 2, 3), keepdim=True)
        std = x.std(dim=(1, 2, 3), keepdim=True)
        x = (x - mean) / (1e-5 + std)

        xt = mix
        meant = xt.mean(dim=(1, 2), keepdim=True)
        stdt = xt.std(dim=(1, 2), keepdim=True)
        xt = (xt - meant) / (1e-5 + stdt)

        saved = []
        saved_t = []
        lengths = []
        lengths_t = []

        for idx, encode in enumerate(self.encoder):
            lengths.append(x.shape[-1])
            inject = None

            if idx < len(self.tencoder):
                lengths_t.append(xt.shape[-1])
                tenc = self.tencoder[idx]
                xt = tenc(xt)

                if not tenc.empty: saved_t.append(xt)
                else: inject = xt

            x = encode(x, inject)

            if idx == 0 and self.freq_emb is not None:
                frs = torch.arange(x.shape[-2], device=x.device)
                emb = self.freq_emb(frs).t()[None, :, :, None].expand_as(x)
                x = x + self.freq_emb_scale * emb

            saved.append(x)

        if self.crosstransformer:
            if self.bottom_channels:
                b, c, f, t = x.shape
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_upsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)

                xt = self.channel_upsampler_t(xt)

            x, xt = self.crosstransformer(x, xt)

            if self.bottom_channels:
                x = rearrange(x, "b c f t-> b c (f t)")
                x = self.channel_downsampler(x)
                x = rearrange(x, "b c (f t)-> b c f t", f=f)

                xt = self.channel_downsampler_t(xt)

        for idx, decode in enumerate(self.decoder):
            skip = saved.pop(-1)
            x, pre = decode(x, skip, lengths.pop(-1))

            offset = self.depth - len(self.tdecoder)

            if idx >= offset:
                tdec = self.tdecoder[idx - offset]
                length_t = lengths_t.pop(-1)

                if tdec.empty:
                    assert pre.shape[2] == 1, pre.shape
                    pre = pre[:, :, 0]
                    xt, _ = tdec(pre, None, length_t)
                else:
                    skip = saved_t.pop(-1)
                    xt, _ = tdec(xt, skip, length_t)

        assert len(saved) == 0
        assert len(lengths_t) == 0
        assert len(saved_t) == 0

        S = len(self.sources)
        x = x.view(B, S, -1, Fq, T)
        x = x * std[:, None] + mean[:, None]

        device_type = x.device.type
        device_load = f"{device_type}:{x.device.index}" if not device_type == "mps" else device_type
        x_is_other_gpu = not device_type in ["cuda", "cpu"]

        if x_is_other_gpu: x = x.cpu()

        zout = self._mask(z, x)

        if self.use_train_segment: x = self._ispec(zout, length) if self.training else self._ispec(zout, training_length)
        else: x = self._ispec(zout, length)

        if x_is_other_gpu: x = x.to(device_load)

        if self.use_train_segment: xt = xt.view(B, S, -1, length) if self.training else xt.view(B, S, -1, training_length)
        else: xt = xt.view(B, S, -1, length)

        xt = xt * stdt[:, None] + meant[:, None]
        x = xt + x

        if length_pre_pad: x = x[..., :length_pre_pad]

        return x
main/library/uvr5_separator/demucs/states.py
DELETED
@@ -1,70 +0,0 @@
import os
import sys
import torch
import inspect
import warnings
import functools

from pathlib import Path
from diffq import restore_quantized_state

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config

translations = Config().translations


def load_model(path_or_package, strict=False):
    if isinstance(path_or_package, dict): package = path_or_package
    elif isinstance(path_or_package, (str, Path)):
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")

            path = path_or_package
            package = torch.load(path, map_location="cpu")
    else: raise ValueError(f"{translations['type_not_valid']} {path_or_package}.")

    klass = package["klass"]
    args = package["args"]
    kwargs = package["kwargs"]

    if strict: model = klass(*args, **kwargs)
    else:
        sig = inspect.signature(klass)

        for key in list(kwargs):
            if key not in sig.parameters:
                warnings.warn(translations["del_parameter"] + key)
                del kwargs[key]

        model = klass(*args, **kwargs)

    state = package["state"]
    set_state(model, state)

    return model


def set_state(model, state, quantizer=None):
    if state.get("__quantized"):
        if quantizer is not None: quantizer.restore_quantized_state(model, state["quantized"])
        else: restore_quantized_state(model, state)
    else: model.load_state_dict(state)

    return state


def capture_init(init):
    @functools.wraps(init)
    def __init__(self, *args, **kwargs):
        self._init_args_kwargs = (args, kwargs)
        init(self, *args, **kwargs)

    return __init__
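
For reference, a minimal usage sketch of the helpers above. Illustrative only: the Tiny module and the hand-built package dict are hypothetical, but they exercise the real capture_init / load_model path:

import torch.nn as nn

class Tiny(nn.Module):
    @capture_init
    def __init__(self, dim=8):
        super().__init__()
        self.proj = nn.Linear(dim, dim)

model = Tiny(dim=16)
args, kwargs = model._init_args_kwargs        # recorded by the capture_init decorator
package = {"klass": Tiny, "args": args, "kwargs": kwargs, "state": model.state_dict()}
restored = load_model(package)                # re-instantiates Tiny(dim=16) and loads its weights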
main/library/uvr5_separator/demucs/utils.py
DELETED
@@ -1,10 +0,0 @@
import torch
import typing as tp

def center_trim(tensor: torch.Tensor, reference: tp.Union[torch.Tensor, int]):
    ref_size: int
    ref_size = reference.size(-1) if isinstance(reference, torch.Tensor) else reference
    delta = tensor.size(-1) - ref_size
    if delta < 0: raise ValueError(f"tensor > parameter: {delta}.")
    if delta: tensor = tensor[..., delta // 2 : -(delta - delta // 2)]
    return tensor
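
For reference, center_trim drops an equal number of samples from each end of the last axis so a tensor lines up with a shorter reference; a small illustrative sketch:

import torch

x = torch.arange(10).view(1, 10)   # length 10
ref = torch.zeros(1, 6)            # reference length 6
print(center_trim(x, ref))         # tensor([[2, 3, 4, 5, 6, 7]]), two samples trimmed per side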
main/library/uvr5_separator/spec_utils.py
DELETED
@@ -1,1100 +0,0 @@
import io
import os
import six
import sys
import math
import librosa
import tempfile
import platform
import traceback
import audioread
import subprocess

import numpy as np
import soundfile as sf

from scipy.signal import correlate, hilbert

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
translations = Config().translations

OPERATING_SYSTEM = platform.system()
SYSTEM_ARCH = platform.platform()
SYSTEM_PROC = platform.processor()
ARM = "arm"
AUTO_PHASE = "Automatic"
POSITIVE_PHASE = "Positive Phase"
NEGATIVE_PHASE = "Negative Phase"
NONE_P = ("None",)
LOW_P = ("Shifts: Low",)
MED_P = ("Shifts: Medium",)
HIGH_P = ("Shifts: High",)
VHIGH_P = "Shifts: Very High"
MAXIMUM_P = "Shifts: Maximum"
BASE_PATH_RUB = sys._MEIPASS if getattr(sys, 'frozen', False) else os.path.dirname(os.path.abspath(__file__))
DEVNULL = open(os.devnull, 'w') if six.PY2 else subprocess.DEVNULL
MAX_SPEC = "Max Spec"
MIN_SPEC = "Min Spec"
LIN_ENSE = "Linear Ensemble"
MAX_WAV = MAX_SPEC
MIN_WAV = MIN_SPEC
AVERAGE = "Average"

progress_value = 0
last_update_time = 0
is_macos = False

if OPERATING_SYSTEM == "Darwin":
    wav_resolution = "polyphase" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else "sinc_fastest"
    wav_resolution_float_resampling = "kaiser_best" if SYSTEM_PROC == ARM or ARM in SYSTEM_ARCH else wav_resolution
    is_macos = True
else:
    wav_resolution = "sinc_fastest"
    wav_resolution_float_resampling = wav_resolution

def crop_center(h1, h2):
    h1_shape = h1.size()
    h2_shape = h2.size()

    if h1_shape[3] == h2_shape[3]: return h1
    elif h1_shape[3] < h2_shape[3]: raise ValueError("h1_shape[3] > h2_shape[3]")

    s_time = (h1_shape[3] - h2_shape[3]) // 2
    e_time = s_time + h2_shape[3]

    h1 = h1[:, :, :, s_time:e_time]

    return h1

def preprocess(X_spec):
    return np.abs(X_spec), np.angle(X_spec)

def make_padding(width, cropsize, offset):
    left = offset
    roi_size = cropsize - offset * 2

    if roi_size == 0: roi_size = cropsize

    right = roi_size - (width % roi_size) + left

    return left, right, roi_size

def normalize(wave, max_peak=1.0):
    maxv = np.abs(wave).max()

    if maxv > max_peak: wave *= max_peak / maxv

    return wave

def auto_transpose(audio_array: np.ndarray):
    if audio_array.shape[1] == 2: return audio_array.T

    return audio_array

def write_array_to_mem(audio_data, subtype):
    if isinstance(audio_data, np.ndarray):
        audio_buffer = io.BytesIO()
        sf.write(audio_buffer, audio_data, 44100, subtype=subtype, format="WAV")

        audio_buffer.seek(0)

        return audio_buffer
    else: return audio_data

def spectrogram_to_image(spec, mode="magnitude"):
    if mode == "magnitude":
        y = np.abs(spec) if np.iscomplexobj(spec) else spec
        y = np.log10(y**2 + 1e-8)
    elif mode == "phase":
        y = np.angle(spec) if np.iscomplexobj(spec) else spec

    y -= y.min()
    y *= 255 / y.max()
    img = np.uint8(y)

    if y.ndim == 3:
        img = img.transpose(1, 2, 0)
        img = np.concatenate([np.max(img, axis=2, keepdims=True), img], axis=2)

    return img

def reduce_vocal_aggressively(X, y, softmask):
    v = X - y

    y_mag_tmp = np.abs(y)
    v_mag_tmp = np.abs(v)

    v_mask = v_mag_tmp > y_mag_tmp
    y_mag = np.clip(y_mag_tmp - v_mag_tmp * v_mask * softmask, 0, np.inf)

    return y_mag * np.exp(1.0j * np.angle(y))

def merge_artifacts(y_mask, thres=0.01, min_range=64, fade_size=32):
    mask = y_mask

    try:
        if min_range < fade_size * 2: raise ValueError("min_range >= fade_size * 2")

        idx = np.where(y_mask.min(axis=(0, 1)) > thres)[0]
        start_idx = np.insert(idx[np.where(np.diff(idx) != 1)[0] + 1], 0, idx[0])
        end_idx = np.append(idx[np.where(np.diff(idx) != 1)[0]], idx[-1])
        artifact_idx = np.where(end_idx - start_idx > min_range)[0]
        weight = np.zeros_like(y_mask)

        if len(artifact_idx) > 0:
            start_idx = start_idx[artifact_idx]
            end_idx = end_idx[artifact_idx]
            old_e = None

            for s, e in zip(start_idx, end_idx):
                if old_e is not None and s - old_e < fade_size: s = old_e - fade_size * 2

                if s != 0: weight[:, :, s : s + fade_size] = np.linspace(0, 1, fade_size)
                else: s -= fade_size

                if e != y_mask.shape[2]: weight[:, :, e - fade_size : e] = np.linspace(1, 0, fade_size)
                else: e += fade_size

                weight[:, :, s + fade_size : e - fade_size] = 1
                old_e = e

        v_mask = 1 - y_mask
        y_mask += weight * v_mask
        mask = y_mask
    except Exception as e:
        error_name = f"{type(e).__name__}"
        traceback_text = "".join(traceback.format_tb(e.__traceback__))
        message = f'{error_name}: "{e}"\n{traceback_text}"'
        print(translations["not_success"], message)

    return mask

def align_wave_head_and_tail(a, b):
    l = min([a[0].size, b[0].size])

    return a[:l, :l], b[:l, :l]

def convert_channels(spec, mp, band):
    cc = mp.param["band"][band].get("convert_channels")

    if "mid_side_c" == cc:
        spec_left = np.add(spec[0], spec[1] * 0.25)
        spec_right = np.subtract(spec[1], spec[0] * 0.25)
    elif "mid_side" == cc:
        spec_left = np.add(spec[0], spec[1]) / 2
        spec_right = np.subtract(spec[0], spec[1])
    elif "stereo_n" == cc:
        spec_left = np.add(spec[0], spec[1] * 0.25) / 0.9375
        spec_right = np.add(spec[1], spec[0] * 0.25) / 0.9375
    else: return spec

    return np.asfortranarray([spec_left, spec_right])

def combine_spectrograms(specs, mp, is_v51_model=False):
    l = min([specs[i].shape[2] for i in specs])
    spec_c = np.zeros(shape=(2, mp.param["bins"] + 1, l), dtype=np.complex64)
    offset = 0
    bands_n = len(mp.param["band"])

    for d in range(1, bands_n + 1):
        h = mp.param["band"][d]["crop_stop"] - mp.param["band"][d]["crop_start"]
        spec_c[:, offset : offset + h, :l] = specs[d][:, mp.param["band"][d]["crop_start"] : mp.param["band"][d]["crop_stop"], :l]
        offset += h

    if offset > mp.param["bins"]: raise ValueError("Too many bins")

    if mp.param["pre_filter_start"] > 0:
        if is_v51_model: spec_c *= get_lp_filter_mask(spec_c.shape[1], mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
        else:
            if bands_n == 1: spec_c = fft_lp_filter(spec_c, mp.param["pre_filter_start"], mp.param["pre_filter_stop"])
            else:
                gp = 1

                for b in range(mp.param["pre_filter_start"] + 1, mp.param["pre_filter_stop"]):
                    g = math.pow(10, -(b - mp.param["pre_filter_start"]) * (3.5 - gp) / 20.0)
                    gp = g
                    spec_c[:, b, :] *= g

    return np.asfortranarray(spec_c)

def wave_to_spectrogram(wave, hop_length, n_fft, mp, band, is_v51_model=False):
    if wave.ndim == 1: wave = np.asfortranarray([wave, wave])

    if not is_v51_model:
        if mp.param["reverse"]:
            wave_left = np.flip(np.asfortranarray(wave[0]))
            wave_right = np.flip(np.asfortranarray(wave[1]))
        elif mp.param["mid_side"]:
            wave_left = np.asfortranarray(np.add(wave[0], wave[1]) / 2)
            wave_right = np.asfortranarray(np.subtract(wave[0], wave[1]))
        elif mp.param["mid_side_b2"]:
            wave_left = np.asfortranarray(np.add(wave[1], wave[0] * 0.5))
            wave_right = np.asfortranarray(np.subtract(wave[0], wave[1] * 0.5))
        else:
            wave_left = np.asfortranarray(wave[0])
            wave_right = np.asfortranarray(wave[1])
    else:
        wave_left = np.asfortranarray(wave[0])
        wave_right = np.asfortranarray(wave[1])

    spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
    spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)

    spec = np.asfortranarray([spec_left, spec_right])

    if is_v51_model: spec = convert_channels(spec, mp, band)

    return spec

def spectrogram_to_wave(spec, hop_length=1024, mp={}, band=0, is_v51_model=True):
    spec_left = np.asfortranarray(spec[0])
    spec_right = np.asfortranarray(spec[1])

    wave_left = librosa.istft(spec_left, hop_length=hop_length)
    wave_right = librosa.istft(spec_right, hop_length=hop_length)

    if is_v51_model:
        cc = mp.param["band"][band].get("convert_channels")

        if "mid_side_c" == cc: return np.asfortranarray([np.subtract(wave_left / 1.0625, wave_right / 4.25), np.add(wave_right / 1.0625, wave_left / 4.25)])
        elif "mid_side" == cc: return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
        elif "stereo_n" == cc: return np.asfortranarray([np.subtract(wave_left, wave_right * 0.25), np.subtract(wave_right, wave_left * 0.25)])
    else:
        if mp.param["reverse"]: return np.asfortranarray([np.flip(wave_left), np.flip(wave_right)])
        elif mp.param["mid_side"]: return np.asfortranarray([np.add(wave_left, wave_right / 2), np.subtract(wave_left, wave_right / 2)])
        elif mp.param["mid_side_b2"]: return np.asfortranarray([np.add(wave_right / 1.25, 0.4 * wave_left), np.subtract(wave_left / 1.25, 0.4 * wave_right)])

    return np.asfortranarray([wave_left, wave_right])

def cmb_spectrogram_to_wave(spec_m, mp, extra_bins_h=None, extra_bins=None, is_v51_model=False):
    bands_n = len(mp.param["band"])
    offset = 0

    for d in range(1, bands_n + 1):
        bp = mp.param["band"][d]
        spec_s = np.zeros(shape=(2, bp["n_fft"] // 2 + 1, spec_m.shape[2]), dtype=complex)
        h = bp["crop_stop"] - bp["crop_start"]
        spec_s[:, bp["crop_start"] : bp["crop_stop"], :] = spec_m[:, offset : offset + h, :]

        offset += h

        if d == bands_n:
            if extra_bins_h:
                max_bin = bp["n_fft"] // 2
                spec_s[:, max_bin - extra_bins_h : max_bin, :] = extra_bins[:, :extra_bins_h, :]

            if bp["hpf_start"] > 0:
                if is_v51_model: spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
                else: spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)

            if bands_n == 1: wave = spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model)
            else: wave = np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))
        else:
            sr = mp.param["band"][d + 1]["sr"]
            if d == 1:
                if is_v51_model: spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
                else: spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])

                try:
                    wave = librosa.resample(spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model), orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
                except ValueError as e:
                    print(f"{translations['resample_error']}: {e}")
                    print(f"{translations['shapes']} Spec_s: {spec_s.shape}, SR: {sr}, {translations['wav_resolution']}: {wav_resolution}")
            else:
                if is_v51_model:
                    spec_s *= get_hp_filter_mask(spec_s.shape[1], bp["hpf_start"], bp["hpf_stop"] - 1)
                    spec_s *= get_lp_filter_mask(spec_s.shape[1], bp["lpf_start"], bp["lpf_stop"])
                else:
                    spec_s = fft_hp_filter(spec_s, bp["hpf_start"], bp["hpf_stop"] - 1)
                    spec_s = fft_lp_filter(spec_s, bp["lpf_start"], bp["lpf_stop"])

                wave2 = np.add(wave, spectrogram_to_wave(spec_s, bp["hl"], mp, d, is_v51_model))

                try:
                    wave = librosa.resample(wave2, orig_sr=bp["sr"], target_sr=sr, res_type=wav_resolution)
                except ValueError as e:
                    print(f"{translations['resample_error']}: {e}")
                    print(f"{translations['shapes']} Spec_s: {spec_s.shape}, SR: {sr}, {translations['wav_resolution']}: {wav_resolution}")

    return wave

def get_lp_filter_mask(n_bins, bin_start, bin_stop):
    return np.concatenate([np.ones((bin_start - 1, 1)), np.linspace(1, 0, bin_stop - bin_start + 1)[:, None], np.zeros((n_bins - bin_stop, 1))], axis=0)

def get_hp_filter_mask(n_bins, bin_start, bin_stop):
    return np.concatenate([np.zeros((bin_stop + 1, 1)), np.linspace(0, 1, 1 + bin_start - bin_stop)[:, None], np.ones((n_bins - bin_start - 2, 1))], axis=0)

def fft_lp_filter(spec, bin_start, bin_stop):
    g = 1.0

    for b in range(bin_start, bin_stop):
        g -= 1 / (bin_stop - bin_start)
        spec[:, b, :] = g * spec[:, b, :]

    spec[:, bin_stop:, :] *= 0

    return spec

def fft_hp_filter(spec, bin_start, bin_stop):
    g = 1.0

    for b in range(bin_start, bin_stop, -1):
        g -= 1 / (bin_start - bin_stop)
        spec[:, b, :] = g * spec[:, b, :]

    spec[:, 0 : bin_stop + 1, :] *= 0

    return spec

def spectrogram_to_wave_old(spec, hop_length=1024):
    if spec.ndim == 2: wave = librosa.istft(spec, hop_length=hop_length)
    elif spec.ndim == 3:
        spec_left = np.asfortranarray(spec[0])
        spec_right = np.asfortranarray(spec[1])

        wave_left = librosa.istft(spec_left, hop_length=hop_length)
        wave_right = librosa.istft(spec_right, hop_length=hop_length)
        wave = np.asfortranarray([wave_left, wave_right])

    return wave

def wave_to_spectrogram_old(wave, hop_length, n_fft):
    wave_left = np.asfortranarray(wave[0])
    wave_right = np.asfortranarray(wave[1])
    spec_left = librosa.stft(wave_left, n_fft=n_fft, hop_length=hop_length)
    spec_right = librosa.stft(wave_right, n_fft=n_fft, hop_length=hop_length)

    return np.asfortranarray([spec_left, spec_right])

def mirroring(a, spec_m, input_high_end, mp):
    if "mirroring" == a:
        mirror = np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1)
        mirror = mirror * np.exp(1.0j * np.angle(input_high_end))

        return np.where(np.abs(input_high_end) <= np.abs(mirror), input_high_end, mirror)

    if "mirroring2" == a:
        mirror = np.flip(np.abs(spec_m[:, mp.param["pre_filter_start"] - 10 - input_high_end.shape[1] : mp.param["pre_filter_start"] - 10, :]), 1)
        mi = np.multiply(mirror, input_high_end * 1.7)

        return np.where(np.abs(input_high_end) <= np.abs(mi), input_high_end, mi)

def adjust_aggr(mask, is_non_accom_stem, aggressiveness):
    aggr = aggressiveness["value"] * 2

    if aggr != 0:
        if is_non_accom_stem:
            aggr = 1 - aggr

        if np.any(aggr > 10) or np.any(aggr < -10): print(f"{translations['warnings']}: {aggr}")

        aggr = [aggr, aggr]

        if aggressiveness["aggr_correction"] is not None:
            aggr[0] += aggressiveness["aggr_correction"]["left"]
            aggr[1] += aggressiveness["aggr_correction"]["right"]

        for ch in range(2):
            mask[ch, : aggressiveness["split_bin"]] = np.power(mask[ch, : aggressiveness["split_bin"]], 1 + aggr[ch] / 3)
            mask[ch, aggressiveness["split_bin"] :] = np.power(mask[ch, aggressiveness["split_bin"] :], 1 + aggr[ch])

    return mask

def stft(wave, nfft, hl):
    wave_left = np.asfortranarray(wave[0])
    wave_right = np.asfortranarray(wave[1])
    spec_left = librosa.stft(wave_left, n_fft=nfft, hop_length=hl)
    spec_right = librosa.stft(wave_right, n_fft=nfft, hop_length=hl)
    spec = np.asfortranarray([spec_left, spec_right])

    return spec

def istft(spec, hl):
    spec_left = np.asfortranarray(spec[0])
    spec_right = np.asfortranarray(spec[1])
    wave_left = librosa.istft(spec_left, hop_length=hl)
    wave_right = librosa.istft(spec_right, hop_length=hl)
    wave = np.asfortranarray([wave_left, wave_right])

    return wave

def spec_effects(wave, algorithm="Default", value=None):
    if np.isnan(wave).any() or np.isinf(wave).any(): print(f"{translations['warnings_2']}: {wave.shape}")

    spec = [stft(wave[0], 2048, 1024), stft(wave[1], 2048, 1024)]

    if algorithm == "Min_Mag":
        v_spec_m = np.where(np.abs(spec[1]) <= np.abs(spec[0]), spec[1], spec[0])
        wave = istft(v_spec_m, 1024)
    elif algorithm == "Max_Mag":
        v_spec_m = np.where(np.abs(spec[1]) >= np.abs(spec[0]), spec[1], spec[0])
        wave = istft(v_spec_m, 1024)
    elif algorithm == "Default": wave = (wave[1] * value) + (wave[0] * (1 - value))
    elif algorithm == "Invert_p":
        X_mag = np.abs(spec[0])
        y_mag = np.abs(spec[1])

        max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
        v_spec = spec[1] - max_mag * np.exp(1.0j * np.angle(spec[0]))

        wave = istft(v_spec, 1024)

    return wave

def spectrogram_to_wave_no_mp(spec, n_fft=2048, hop_length=1024):
    wave = librosa.istft(spec, n_fft=n_fft, hop_length=hop_length)
    if wave.ndim == 1: wave = np.asfortranarray([wave, wave])

    return wave

def wave_to_spectrogram_no_mp(wave):
    spec = librosa.stft(wave, n_fft=2048, hop_length=1024)

    if spec.ndim == 1: spec = np.asfortranarray([spec, spec])

    return spec

def invert_audio(specs, invert_p=True):
    ln = min([specs[0].shape[2], specs[1].shape[2]])
    specs[0] = specs[0][:, :, :ln]
    specs[1] = specs[1][:, :, :ln]

    if invert_p:
        X_mag = np.abs(specs[0])
        y_mag = np.abs(specs[1])
        max_mag = np.where(X_mag >= y_mag, X_mag, y_mag)
        v_spec = specs[1] - max_mag * np.exp(1.0j * np.angle(specs[0]))
    else:
        specs[1] = reduce_vocal_aggressively(specs[0], specs[1], 0.2)
        v_spec = specs[0] - specs[1]

    return v_spec

def invert_stem(mixture, stem):
    mixture = wave_to_spectrogram_no_mp(mixture)
    stem = wave_to_spectrogram_no_mp(stem)
    output = spectrogram_to_wave_no_mp(invert_audio([mixture, stem]))

    return -output.T

def ensembling(a, inputs, is_wavs=False):
    for i in range(1, len(inputs)):
        if i == 1: input = inputs[0]

        if is_wavs:
            ln = min([input.shape[1], inputs[i].shape[1]])
            input = input[:, :ln]
            inputs[i] = inputs[i][:, :ln]
        else:
            ln = min([input.shape[2], inputs[i].shape[2]])
            input = input[:, :, :ln]
            inputs[i] = inputs[i][:, :, :ln]

        if MIN_SPEC == a: input = np.where(np.abs(inputs[i]) <= np.abs(input), inputs[i], input)
        if MAX_SPEC == a: input = np.where(np.abs(inputs[i]) >= np.abs(input), inputs[i], input)

    return input

def ensemble_for_align(waves):
    specs = []

    for wav in waves:
        spec = wave_to_spectrogram_no_mp(wav.T)
        specs.append(spec)

    wav_aligned = spectrogram_to_wave_no_mp(ensembling(MIN_SPEC, specs)).T
    wav_aligned = match_array_shapes(wav_aligned, waves[1], is_swap=True)

    return wav_aligned

def ensemble_inputs(audio_input, algorithm, is_normalization, wav_type_set, save_path, is_wave=False, is_array=False):
    wavs_ = []

    if algorithm == AVERAGE:
        output = average_audio(audio_input)
        samplerate = 44100
    else:
        specs = []

        for i in range(len(audio_input)):
            wave, samplerate = librosa.load(audio_input[i], mono=False, sr=44100)
            wavs_.append(wave)
            spec = wave if is_wave else wave_to_spectrogram_no_mp(wave)
            specs.append(spec)

        wave_shapes = [w.shape[1] for w in wavs_]
        target_shape = wavs_[wave_shapes.index(max(wave_shapes))]

        output = ensembling(algorithm, specs, is_wavs=True) if is_wave else spectrogram_to_wave_no_mp(ensembling(algorithm, specs))
        output = to_shape(output, target_shape.shape)

    sf.write(save_path, normalize(output.T, is_normalization), samplerate, subtype=wav_type_set)

def to_shape(x, target_shape):
    padding_list = []

    for x_dim, target_dim in zip(x.shape, target_shape):
        pad_value = target_dim - x_dim
        pad_tuple = (0, pad_value)
        padding_list.append(pad_tuple)

    return np.pad(x, tuple(padding_list), mode="constant")

def to_shape_minimize(x: np.ndarray, target_shape):
    padding_list = []

    for x_dim, target_dim in zip(x.shape, target_shape):
        pad_value = target_dim - x_dim
        pad_tuple = (0, pad_value)
        padding_list.append(pad_tuple)

    return np.pad(x, tuple(padding_list), mode="constant")

def detect_leading_silence(audio, sr, silence_threshold=0.007, frame_length=1024):
    if len(audio.shape) == 2:
        channel = np.argmax(np.sum(np.abs(audio), axis=1))
        audio = audio[channel]

    for i in range(0, len(audio), frame_length):
        if np.max(np.abs(audio[i : i + frame_length])) > silence_threshold: return (i / sr) * 1000

    return (len(audio) / sr) * 1000

def adjust_leading_silence(target_audio, reference_audio, silence_threshold=0.01, frame_length=1024):
    def find_silence_end(audio):
        if len(audio.shape) == 2:
            channel = np.argmax(np.sum(np.abs(audio), axis=1))
            audio_mono = audio[channel]
        else: audio_mono = audio

        for i in range(0, len(audio_mono), frame_length):
            if np.max(np.abs(audio_mono[i : i + frame_length])) > silence_threshold: return i

        return len(audio_mono)

    ref_silence_end = find_silence_end(reference_audio)
    target_silence_end = find_silence_end(target_audio)
    silence_difference = ref_silence_end - target_silence_end

    try:
        ref_silence_end_p = (ref_silence_end / 44100) * 1000
        target_silence_end_p = (target_silence_end / 44100) * 1000
        silence_difference_p = ref_silence_end_p - target_silence_end_p

        print("silence difference: ", silence_difference_p)
    except Exception as e:
        pass

    if silence_difference > 0:
        silence_to_add = np.zeros((target_audio.shape[0], silence_difference)) if len(target_audio.shape) == 2 else np.zeros(silence_difference)

        return np.hstack((silence_to_add, target_audio))
    elif silence_difference < 0:
        if len(target_audio.shape) == 2: return target_audio[:, -silence_difference:]
        else: return target_audio[-silence_difference:]
    else: return target_audio

def match_array_shapes(array_1: np.ndarray, array_2: np.ndarray, is_swap=False):
    if is_swap: array_1, array_2 = array_1.T, array_2.T

    if array_1.shape[1] > array_2.shape[1]: array_1 = array_1[:, : array_2.shape[1]]
    elif array_1.shape[1] < array_2.shape[1]:
        padding = array_2.shape[1] - array_1.shape[1]
        array_1 = np.pad(array_1, ((0, 0), (0, padding)), "constant", constant_values=0)

    if is_swap: array_1, array_2 = array_1.T, array_2.T

    return array_1

def match_mono_array_shapes(array_1: np.ndarray, array_2: np.ndarray):
    if len(array_1) > len(array_2): array_1 = array_1[: len(array_2)]
    elif len(array_1) < len(array_2):
        padding = len(array_2) - len(array_1)
        array_1 = np.pad(array_1, (0, padding), "constant", constant_values=0)

    return array_1

def change_pitch_semitones(y, sr, semitone_shift):
    factor = 2 ** (semitone_shift / 12)
    y_pitch_tuned = []

    for y_channel in y:
        y_pitch_tuned.append(librosa.resample(y_channel, orig_sr=sr, target_sr=sr * factor, res_type=wav_resolution_float_resampling))

    y_pitch_tuned = np.array(y_pitch_tuned)
    new_sr = sr * factor

    return y_pitch_tuned, new_sr

def augment_audio(export_path, audio_file, rate, is_normalization, wav_type_set, save_format=None, is_pitch=False, is_time_correction=True):
    wav, sr = librosa.load(audio_file, sr=44100, mono=False)

    if wav.ndim == 1: wav = np.asfortranarray([wav, wav])

    if not is_time_correction: wav_mix = change_pitch_semitones(wav, 44100, semitone_shift=-rate)[0]
    else:
        if is_pitch:
            wav_1 = pitch_shift(wav[0], sr, rate, rbargs=None)
            wav_2 = pitch_shift(wav[1], sr, rate, rbargs=None)
        else:
            wav_1 = time_stretch(wav[0], sr, rate, rbargs=None)
            wav_2 = time_stretch(wav[1], sr, rate, rbargs=None)

        if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)
        if wav_1.shape < wav_2.shape: wav_1 = to_shape(wav_1, wav_2.shape)

        wav_mix = np.asfortranarray([wav_1, wav_2])

    sf.write(export_path, normalize(wav_mix.T, is_normalization), sr, subtype=wav_type_set)
    save_format(export_path)


def average_audio(audio):
    waves = []
    wave_shapes = []
    final_waves = []

    for i in range(len(audio)):
        wave = librosa.load(audio[i], sr=44100, mono=False)
        waves.append(wave[0])
        wave_shapes.append(wave[0].shape[1])

    wave_shapes_index = wave_shapes.index(max(wave_shapes))
    target_shape = waves[wave_shapes_index]

    waves.pop(wave_shapes_index)
    final_waves.append(target_shape)

    for n_array in waves:
        wav_target = to_shape(n_array, target_shape.shape)
        final_waves.append(wav_target)

    waves = sum(final_waves)
    waves = waves / len(audio)

    return waves

def average_dual_sources(wav_1, wav_2, value):
    if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)
    if wav_1.shape < wav_2.shape: wav_1 = to_shape(wav_1, wav_2.shape)

    wave = (wav_1 * value) + (wav_2 * (1 - value))

    return wave

def reshape_sources(wav_1: np.ndarray, wav_2: np.ndarray):
    if wav_1.shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1.shape)

    if wav_1.shape < wav_2.shape:
        ln = min([wav_1.shape[1], wav_2.shape[1]])
        wav_2 = wav_2[:, :ln]

    ln = min([wav_1.shape[1], wav_2.shape[1]])
    wav_1 = wav_1[:, :ln]
    wav_2 = wav_2[:, :ln]

    return wav_2

def reshape_sources_ref(wav_1_shape, wav_2: np.ndarray):
    if wav_1_shape > wav_2.shape: wav_2 = to_shape(wav_2, wav_1_shape)

    return wav_2

def combine_arrarys(audio_sources, is_swap=False):
    source = np.zeros_like(max(audio_sources, key=np.size))

    for v in audio_sources:
        v = match_array_shapes(v, source, is_swap=is_swap)
        source += v

    return source

def combine_audio(paths: list, audio_file_base=None, wav_type_set="FLOAT", save_format=None):
    source = combine_arrarys([load_audio(i) for i in paths])
    save_path = f"{audio_file_base}_combined.wav"

    sf.write(save_path, source.T, 44100, subtype=wav_type_set)
    save_format(save_path)

def reduce_mix_bv(inst_source, voc_source, reduction_rate=0.9):
    inst_source = inst_source * (1 - reduction_rate)
    mix_reduced = combine_arrarys([inst_source, voc_source], is_swap=True)

    return mix_reduced

def organize_inputs(inputs):
    input_list = {"target": None, "reference": None, "reverb": None, "inst": None}

    for i in inputs:
        if i.endswith("_(Vocals).wav"): input_list["reference"] = i
        elif "_RVC_" in i: input_list["target"] = i
        elif i.endswith("reverbed_stem.wav"): input_list["reverb"] = i
        elif i.endswith("_(Instrumental).wav"): input_list["inst"] = i

    return input_list

def check_if_phase_inverted(wav1, wav2, is_mono=False):
    if not is_mono:
        wav1 = np.mean(wav1, axis=0)
        wav2 = np.mean(wav2, axis=0)

    correlation = np.corrcoef(wav1[:1000], wav2[:1000])

    return correlation[0, 1] < 0

def align_audio(file1, file2, file2_aligned, file_subtracted, wav_type_set, is_save_aligned, command_Text, save_format, align_window: list, align_intro_val: list, db_analysis: tuple, set_progress_bar, phase_option, phase_shifts, is_match_silence, is_spec_match):
    global progress_value
    progress_value = 0
    is_mono = False

    def get_diff(a, b):
        corr = np.correlate(a, b, "full")
        diff = corr.argmax() - (b.shape[0] - 1)

        return diff

    def progress_bar(length):
        global progress_value

        progress_value += 1

        if (0.90 / length * progress_value) >= 0.9: length = progress_value + 1

        set_progress_bar(0.1, (0.9 / length * progress_value))

    if file1.endswith(".mp3") and is_macos:
        length1 = rerun_mp3(file1)
        wav1, sr1 = librosa.load(file1, duration=length1, sr=44100, mono=False)
    else: wav1, sr1 = librosa.load(file1, sr=44100, mono=False)

    if file2.endswith(".mp3") and is_macos:
        length2 = rerun_mp3(file2)
        wav2, sr2 = librosa.load(file2, duration=length2, sr=44100, mono=False)
    else: wav2, sr2 = librosa.load(file2, sr=44100, mono=False)

    if wav1.ndim == 1 and wav2.ndim == 1: is_mono = True
    elif wav1.ndim == 1: wav1 = np.asfortranarray([wav1, wav1])
    elif wav2.ndim == 1: wav2 = np.asfortranarray([wav2, wav2])

    if phase_option == AUTO_PHASE:
        if check_if_phase_inverted(wav1, wav2, is_mono=is_mono): wav2 = -wav2
    elif phase_option == POSITIVE_PHASE: wav2 = +wav2
    elif phase_option == NEGATIVE_PHASE: wav2 = -wav2

    if is_match_silence: wav2 = adjust_leading_silence(wav2, wav1)

    wav1_length = int(librosa.get_duration(y=wav1, sr=44100))
    wav2_length = int(librosa.get_duration(y=wav2, sr=44100))

    if not is_mono:
        wav1 = wav1.transpose()
        wav2 = wav2.transpose()

    wav2_org = wav2.copy()

    command_Text(translations["process_file"])
    seconds_length = min(wav1_length, wav2_length)

    wav2_aligned_sources = []

    for sec_len in align_intro_val:
        sec_seg = 1 if sec_len == 1 else int(seconds_length // sec_len)
        index = sr1 * sec_seg

        if is_mono:
            samp1, samp2 = wav1[index : index + sr1], wav2[index : index + sr1]
            diff = get_diff(samp1, samp2)
        else:
            index = sr1 * sec_seg
            samp1, samp2 = wav1[index : index + sr1, 0], wav2[index : index + sr1, 0]
            samp1_r, samp2_r = wav1[index : index + sr1, 1], wav2[index : index + sr1, 1]
            diff, diff_r = get_diff(samp1, samp2), get_diff(samp1_r, samp2_r)

        if diff > 0:
            zeros_to_append = np.zeros(diff) if is_mono else np.zeros((diff, 2))
            wav2_aligned = np.append(zeros_to_append, wav2_org, axis=0)
        elif diff < 0: wav2_aligned = wav2_org[-diff:]
        else: wav2_aligned = wav2_org

        if not any(np.array_equal(wav2_aligned, source) for source in wav2_aligned_sources): wav2_aligned_sources.append(wav2_aligned)

    unique_sources = len(wav2_aligned_sources)

    sub_mapper_big_mapper = {}

    for s in wav2_aligned_sources:
        wav2_aligned = match_mono_array_shapes(s, wav1) if is_mono else match_array_shapes(s, wav1, is_swap=True)

        if align_window:
            wav_sub = time_correction(wav1, wav2_aligned, seconds_length, align_window=align_window, db_analysis=db_analysis, progress_bar=progress_bar, unique_sources=unique_sources, phase_shifts=phase_shifts)
            wav_sub_size = np.abs(wav_sub).mean()

            sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}}
        else:
            wav2_aligned = wav2_aligned * np.power(10, db_analysis[0] / 20)
            db_range = db_analysis[1]

            for db_adjustment in db_range:
                s_adjusted = wav2_aligned * (10 ** (db_adjustment / 20))

                wav_sub = wav1 - s_adjusted
                wav_sub_size = np.abs(wav_sub).mean()

                sub_mapper_big_mapper = {**sub_mapper_big_mapper, **{wav_sub_size: wav_sub}}

    sub_mapper_value_list = list(sub_mapper_big_mapper.values())

    wav_sub = ensemble_for_align(list(sub_mapper_big_mapper.values())) if is_spec_match and len(sub_mapper_value_list) >= 2 else ensemble_wav(list(sub_mapper_big_mapper.values()))

    wav_sub = np.clip(wav_sub, -1, +1)

    command_Text(translations["save_instruments"])

    if is_save_aligned or is_spec_match:
        wav1 = match_mono_array_shapes(wav1, wav_sub) if is_mono else match_array_shapes(wav1, wav_sub, is_swap=True)
        wav2_aligned = wav1 - wav_sub

        if is_spec_match:
            if wav1.ndim == 1 and wav2.ndim == 1:
                wav2_aligned = np.asfortranarray([wav2_aligned, wav2_aligned]).T
                wav1 = np.asfortranarray([wav1, wav1]).T

            wav2_aligned = ensemble_for_align([wav2_aligned, wav1])
            wav_sub = wav1 - wav2_aligned

        if is_save_aligned:
            sf.write(file2_aligned, wav2_aligned, sr1, subtype=wav_type_set)
            save_format(file2_aligned)

    sf.write(file_subtracted, wav_sub, sr1, subtype=wav_type_set)
    save_format(file_subtracted)

def phase_shift_hilbert(signal, degree):
    analytic_signal = hilbert(signal)

    return np.cos(np.radians(degree)) * analytic_signal.real - np.sin(np.radians(degree)) * analytic_signal.imag

def get_phase_shifted_tracks(track, phase_shift):
    if phase_shift == 180: return [track, -track]

    step = phase_shift
    end = 180 - (180 % step) if 180 % step == 0 else 181
    phase_range = range(step, end, step)

    flipped_list = [track, -track]

    for i in phase_range:
        flipped_list.extend([phase_shift_hilbert(track, i), phase_shift_hilbert(track, -i)])

    return flipped_list

def time_correction(mix: np.ndarray, instrumental: np.ndarray, seconds_length, align_window, db_analysis, sr=44100, progress_bar=None, unique_sources=None, phase_shifts=NONE_P):
    def align_tracks(track1, track2):
        shifted_tracks = {}

        track2 = track2 * np.power(10, db_analysis[0] / 20)
        db_range = db_analysis[1]

        track2_flipped = [track2] if phase_shifts == 190 else get_phase_shifted_tracks(track2, phase_shifts)

        for db_adjustment in db_range:
            for t in track2_flipped:
                track2_adjusted = t * (10 ** (db_adjustment / 20))
                corr = correlate(track1, track2_adjusted)
                delay = np.argmax(np.abs(corr)) - (len(track1) - 1)
                track2_shifted = np.roll(track2_adjusted, shift=delay)

                track2_shifted_sub = track1 - track2_shifted
                mean_abs_value = np.abs(track2_shifted_sub).mean()

                shifted_tracks[mean_abs_value] = track2_shifted

        return shifted_tracks[min(shifted_tracks.keys())]

    assert mix.shape == instrumental.shape, translations["assert"].format(mixshape=mix.shape, instrumentalshape=instrumental.shape)

    seconds_length = seconds_length // 2

    sub_mapper = {}

    progress_update_interval = 120
    total_iterations = 0

    if len(align_window) > 2: progress_update_interval = 320

    for secs in align_window:
        step = secs / 2
        window_size = int(sr * secs)
        step_size = int(sr * step)

        if len(mix.shape) == 1:
            total_mono = (len(range(0, len(mix) - window_size, step_size)) // progress_update_interval) * unique_sources
            total_iterations += total_mono
        else:
            total_stereo_ = len(range(0, len(mix[:, 0]) - window_size, step_size)) * 2
            total_stereo = (total_stereo_ // progress_update_interval) * unique_sources
            total_iterations += total_stereo

    for secs in align_window:
        sub = np.zeros_like(mix)
        divider = np.zeros_like(mix)
        step = secs / 2
        window_size = int(sr * secs)
        step_size = int(sr * step)
        window = np.hanning(window_size)

        if len(mix.shape) == 1:
            counter = 0

            for i in range(0, len(mix) - window_size, step_size):
                counter += 1

                if counter % progress_update_interval == 0: progress_bar(total_iterations)

                window_mix = mix[i : i + window_size] * window
                window_instrumental = instrumental[i : i + window_size] * window
                window_instrumental_aligned = align_tracks(window_mix, window_instrumental)

                sub[i : i + window_size] += window_mix - window_instrumental_aligned
                divider[i : i + window_size] += window
        else:
            counter = 0

            for ch in range(mix.shape[1]):
                for i in range(0, len(mix[:, ch]) - window_size, step_size):
                    counter += 1

                    if counter % progress_update_interval == 0: progress_bar(total_iterations)

                    window_mix = mix[i : i + window_size, ch] * window
                    window_instrumental = instrumental[i : i + window_size, ch] * window
                    window_instrumental_aligned = align_tracks(window_mix, window_instrumental)

                    sub[i : i + window_size, ch] += window_mix - window_instrumental_aligned
                    divider[i : i + window_size, ch] += window

        sub = np.where(divider > 1e-6, sub / divider, sub)
        sub_size = np.abs(sub).mean()
        sub_mapper = {**sub_mapper, **{sub_size: sub}}

    sub = ensemble_wav(list(sub_mapper.values()), split_size=12)

    return sub

def ensemble_wav(waveforms, split_size=240):
    waveform_thirds = {i: np.array_split(waveform, split_size) for i, waveform in enumerate(waveforms)}

    final_waveform = []

    for third_idx in range(split_size):
        means = [np.abs(waveform_thirds[i][third_idx]).mean() for i in range(len(waveforms))]

        min_index = np.argmin(means)

        final_waveform.append(waveform_thirds[min_index][third_idx])

    final_waveform = np.concatenate(final_waveform)

    return final_waveform

def ensemble_wav_min(waveforms):
    for i in range(1, len(waveforms)):
        if i == 1: wave = waveforms[0]

        ln = min(len(wave), len(waveforms[i]))
        wave = wave[:ln]
        waveforms[i] = waveforms[i][:ln]
        wave = np.where(np.abs(waveforms[i]) <= np.abs(wave), waveforms[i], wave)

    return wave

def align_audio_test(wav1, wav2, sr1=44100):
    def get_diff(a, b):
        corr = np.correlate(a, b, "full")
        diff = corr.argmax() - (b.shape[0] - 1)
        return diff

    wav1 = wav1.transpose()
    wav2 = wav2.transpose()
    wav2_org = wav2.copy()

    index = sr1
    samp1 = wav1[index : index + sr1, 0]
    samp2 = wav2[index : index + sr1, 0]
    diff = get_diff(samp1, samp2)

    if diff > 0: wav2_aligned = np.append(np.zeros((diff, 1)), wav2_org, axis=0)
    elif diff < 0: wav2_aligned = wav2_org[-diff:]
    else: wav2_aligned = wav2_org

    return wav2_aligned

def load_audio(audio_file):
    wav, sr = librosa.load(audio_file, sr=44100, mono=False)
    if wav.ndim == 1: wav = np.asfortranarray([wav, wav])

    return wav

def rerun_mp3(audio_file):
    with audioread.audio_open(audio_file) as f:
        track_length = int(f.duration)

    return track_length

def __rubberband(y, sr, **kwargs):
    assert sr > 0

    fd, infile = tempfile.mkstemp(suffix='.wav')
    os.close(fd)
    fd, outfile = tempfile.mkstemp(suffix='.wav')
    os.close(fd)

    sf.write(infile, y, sr)

    try:
        arguments = [os.path.join(BASE_PATH_RUB, 'rubberband'), '-q']

        for key, value in six.iteritems(kwargs):
            arguments.append(str(key))
            arguments.append(str(value))

        arguments.extend([infile, outfile])

        subprocess.check_call(arguments, stdout=DEVNULL, stderr=DEVNULL)

        y_out, _ = sf.read(outfile, always_2d=True)

        if y.ndim == 1: y_out = np.squeeze(y_out)
    except OSError as exc:
        six.raise_from(RuntimeError(translations["rubberband"]), exc)
    finally:
        os.unlink(infile)
        os.unlink(outfile)

    return y_out

def time_stretch(y, sr, rate, rbargs=None):
    if rate <= 0: raise ValueError(translations["rate"])

    if rate == 1.0: return y
    if rbargs is None: rbargs = dict()

    rbargs.setdefault('--tempo', rate)

    return __rubberband(y, sr, **rbargs)

def pitch_shift(y, sr, n_steps, rbargs=None):
    if n_steps == 0: return y
    if rbargs is None: rbargs = dict()

    rbargs.setdefault('--pitch', n_steps)

    return __rubberband(y, sr, **rbargs)
main/tools/gdown.py
DELETED
@@ -1,230 +0,0 @@
import os
import re
import sys
import six
import json
import tqdm
import shutil
import tempfile
import requests
import warnings
import textwrap

from time import sleep, time
from urllib.parse import urlparse, parse_qs, unquote

now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
translations = Config().translations

CHUNK_SIZE = 512 * 1024
HOME = os.path.expanduser("~")

def indent(text, prefix):
    return "".join((prefix + line if line.strip() else line) for line in text.splitlines(True))


def parse_url(url, warning=True):
    parsed = urlparse(url)
    is_download_link = parsed.path.endswith("/uc")

    if not parsed.hostname in ("drive.google.com", "docs.google.com"): return None, is_download_link

    file_id = parse_qs(parsed.query).get("id", [None])[0]

    if file_id is None:
        for pattern in (r"^/file/d/(.*?)/(edit|view)$", r"^/file/u/[0-9]+/d/(.*?)/(edit|view)$", r"^/document/d/(.*?)/(edit|htmlview|view)$", r"^/document/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/d/(.*?)/(edit|htmlview|view)$", r"^/presentation/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/d/(.*?)/(edit|htmlview|view)$", r"^/spreadsheets/u/[0-9]+/d/(.*?)/(edit|htmlview|view)$"):
            match = re.match(pattern, parsed.path)

            if match:
                file_id = match.group(1)
                break

    if warning and not is_download_link:
        warnings.warn(translations["gdown_warning"].format(file_id=file_id))

    return file_id, is_download_link


def get_url_from_gdrive_confirmation(contents):
    for pattern in (r'href="(\/uc\?export=download[^"]+)', r'href="/open\?id=([^"]+)"', r'"downloadUrl":"([^"]+)'):
        match = re.search(pattern, contents)

        if match:
            url = match.group(1)

            if pattern == r'href="/open\?id=([^"]+)"': url = ("https://drive.usercontent.google.com/download?id=" + url + "&confirm=t&uuid=" + re.search(r'<input\s+type="hidden"\s+name="uuid"\s+value="([^"]+)"', contents).group(1))
            elif pattern == r'"downloadUrl":"([^"]+)': url = url.replace("\\u003d", "=").replace("\\u0026", "&")
            else: url = "https://docs.google.com" + url.replace("&amp;", "&")

            return url

    match = re.search(r'<p class="uc-error-subcaption">(.*)</p>', contents)

    if match:
        error = match.group(1)
        raise Exception(error)

    raise Exception(translations["gdown_error"])


def _get_session(proxy, use_cookies, return_cookies_file=False):
    sess = requests.session()
    sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})

    if proxy is not None:
        sess.proxies = {"http": proxy, "https": proxy}
        print("Using proxy:", proxy, file=sys.stderr)

    cookies_file = os.path.join(HOME, ".cache/gdown/cookies.json")

    if os.path.exists(cookies_file) and use_cookies:
        with open(cookies_file) as f:
            cookies = json.load(f)

        for k, v in cookies:
            sess.cookies[k] = v

    return (sess, cookies_file) if return_cookies_file else sess


def gdown_download(url=None, output=None, output_dir=None, quiet=False, proxy=None, speed=None, use_cookies=True, verify=True, id=None, fuzzy=True, resume=False, format=None):
    if not (id is None) ^ (url is None): raise ValueError(translations["gdown_value_error"])
    if id is not None: url = f"https://drive.google.com/uc?id={id}"

    url_origin = url

    sess, cookies_file = _get_session(proxy=proxy, use_cookies=use_cookies, return_cookies_file=True)

    gdrive_file_id, is_gdrive_download_link = parse_url(url, warning=not fuzzy)

    if fuzzy and gdrive_file_id:
        url = f"https://drive.google.com/uc?id={gdrive_file_id}"
        url_origin = url
        is_gdrive_download_link = True

    while 1:
        res = sess.get(url, stream=True, verify=verify)

        if url == url_origin and res.status_code == 500:
            url = f"https://drive.google.com/open?id={gdrive_file_id}"
            continue

        if res.headers["Content-Type"].startswith("text/html"):
            title = re.search("<title>(.+)</title>", res.text)

            if title:
                title = title.group(1)
                if title.endswith(" - Google Docs"):
                    url = f"https://docs.google.com/document/d/{gdrive_file_id}/export?format={'docx' if format is None else format}"
                    continue
                if title.endswith(" - Google Sheets"):
                    url = f"https://docs.google.com/spreadsheets/d/{gdrive_file_id}/export?format={'xlsx' if format is None else format}"
                    continue
                if title.endswith(" - Google Slides"):
                    url = f"https://docs.google.com/presentation/d/{gdrive_file_id}/export?format={'pptx' if format is None else format}"
                    continue
        elif ("Content-Disposition" in res.headers and res.headers["Content-Disposition"].endswith("pptx") and format not in (None, "pptx")):
            url = f"https://docs.google.com/presentation/d/{gdrive_file_id}/export?format={'pptx' if format is None else format}"
            continue

        if use_cookies:
            os.makedirs(os.path.dirname(cookies_file), exist_ok=True)

            with open(cookies_file, "w") as f:
                cookies = [(k, v) for k, v in sess.cookies.items() if not k.startswith("download_warning_")]
                json.dump(cookies, f, indent=2)

        if "Content-Disposition" in res.headers: break
        if not (gdrive_file_id and is_gdrive_download_link): break

        try:
            url = get_url_from_gdrive_confirmation(res.text)
        except Exception as e:
            error = indent("\n".join(textwrap.wrap(str(e))), prefix="\t")
            raise Exception(translations["gdown_error_2"].format(error=error, url_origin=url_origin))

    if gdrive_file_id and is_gdrive_download_link:
        content_disposition = unquote(res.headers["Content-Disposition"])
        filename_from_url = (re.search(r"filename\*=UTF-8''(.*)", content_disposition) or re.search(r'filename=["\']?(.*?)["\']?$', content_disposition)).group(1)
        filename_from_url = filename_from_url.replace(os.path.sep, "_")
    else: filename_from_url = os.path.basename(url)

    output = output or filename_from_url
    output_is_path = isinstance(output, six.string_types)

    if output_is_path and output.endswith(os.path.sep):
        os.makedirs(output, exist_ok=True)
        output = os.path.join(output, filename_from_url)

    if output_is_path:
        temp_dir = os.path.dirname(output) or "."
        prefix = os.path.basename(output)
        existing_tmp_files = [os.path.join(temp_dir, file) for file in os.listdir(temp_dir) if file.startswith(prefix)]

        if resume and existing_tmp_files:
            if len(existing_tmp_files) > 1:
                print(translations["temps"], file=sys.stderr)

                for file in existing_tmp_files:
                    print(f"\t{file}", file=sys.stderr)

                print(translations["del_all_temps"], file=sys.stderr)
                return

            tmp_file = existing_tmp_files[0]
        else:
            resume = False
            tmp_file = tempfile.mktemp(suffix=tempfile.template, prefix=prefix, dir=temp_dir)

        f = open(tmp_file, "ab")
    else:
        tmp_file = None
        f = output

    if tmp_file is not None and f.tell() != 0: res = sess.get(url, headers={"Range": f"bytes={f.tell()}-"}, stream=True, verify=verify)

    if not quiet:
        if resume: print(translations["continue"], tmp_file, file=sys.stderr)

        print(translations["to"], os.path.abspath(output) if output_is_path else output, file=sys.stderr)

    try:
        if not quiet: pbar = tqdm.tqdm(total=int(res.headers.get("Content-Length", 0)))

        t_start = time()

        for chunk in res.iter_content(chunk_size=CHUNK_SIZE):
            f.write(chunk)

            if not quiet: pbar.update(len(chunk))

            if speed is not None:
                elapsed_time_expected = 1.0 * pbar.n / speed
                elapsed_time = time() - t_start

                if elapsed_time < elapsed_time_expected: sleep(elapsed_time_expected - elapsed_time)

        if not quiet: pbar.close()

        if tmp_file:
            f.close()

            if output_dir is not None:
                output_file = os.path.join(output_dir, output)
                if os.path.exists(output_file): os.remove(output_file)

                shutil.move(tmp_file, output_file)
            else:
                if os.path.exists(output): os.remove(output)

                shutil.move(tmp_file, output)
    finally:
        sess.close()

    return output
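
For reference, a minimal sketch of how this deleted downloader could be invoked (the Drive file ID and output path below are illustrative placeholders, not values taken from this commit; gdown_download and its parameters come from the deleted source above):

from main.tools.gdown import gdown_download

# Exactly one of url= or id= must be given; the helper resolves Google Drive
# confirmation pages, optionally resumes a partial temp file, and streams the
# payload to the output path in 512 KiB chunks.
file_path = gdown_download(id="<FILE_ID>", output="downloads/")  # trailing separator: treat output as a directory
print(file_path)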
main/tools/mediafire.py
DELETED
@@ -1,30 +0,0 @@
import os
import sys
import requests
from bs4 import BeautifulSoup

def Mediafire_Download(url, output=None, filename=None):
    if not filename: filename = url.split('/')[-2]
    if not output: output = os.path.dirname(os.path.realpath(__file__))
    output_file = os.path.join(output, filename)

    sess = requests.session()
    sess.headers.update({"User-Agent": "Mozilla/5.0 (Macintosh; Intel Mac OS X 10_12_6)"})

    try:
        with requests.get(BeautifulSoup(sess.get(url).content, "html.parser").find(id="downloadButton").get("href"), stream=True) as r:
            r.raise_for_status()
            with open(output_file, "wb") as f:
                total_length = r.headers.get('content-length')
                total_length = int(total_length)
                download_progress = 0

                for chunk in r.iter_content(chunk_size=1024):
                    download_progress += len(chunk)
                    f.write(chunk)
                    sys.stdout.write(f"\r[{filename}]: {int(100 * download_progress/total_length)}% ({round(download_progress/1024/1024, 2)}mb/{round(total_length/1024/1024, 2)}mb)")
                    sys.stdout.flush()
        sys.stdout.write("\n")
        return output_file
    except Exception as e:
        raise RuntimeError(e)
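
Likewise, a hypothetical call to the deleted MediaFire helper (the share URL, directory, and filename are placeholders; the function and defaults come from the deleted source above):

from main.tools.mediafire import Mediafire_Download

# Scrapes the page's "downloadButton" href with BeautifulSoup, then streams
# the file in 1 KiB chunks while printing percentage/megabyte progress.
saved = Mediafire_Download("https://www.mediafire.com/file/<ID>/<NAME>/file", output="downloads", filename="<NAME>")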
main/tools/meganz.py
DELETED
@@ -1,180 +0,0 @@
import os
import re
import sys
import json
import tqdm
import codecs
import random
import base64
import struct
import shutil
import requests
import tempfile

from Crypto.Cipher import AES
from Crypto.Util import Counter
from tenacity import retry, wait_exponential, retry_if_exception_type


now_dir = os.getcwd()
sys.path.append(now_dir)

from main.configs.config import Config
translations = Config().translations


def makebyte(x):
    return codecs.latin_1_encode(x)[0]


def a32_to_str(a):
    return struct.pack('>%dI' % len(a), *a)


def get_chunks(size):
    p = 0
    s = 0x20000

    while p + s < size:
        yield (p, s)
        p += s

        if s < 0x100000: s += 0x20000

    yield (p, size - p)


def decrypt_attr(attr, key):
    attr = AES.new(a32_to_str(key), AES.MODE_CBC, makebyte('\0' * 16)).decrypt(attr)
    attr = codecs.latin_1_decode(attr)[0]
    attr = attr.rstrip('\0')

    return json.loads(attr[4:]) if attr[:6] == 'MEGA{"' else False


@retry(retry=retry_if_exception_type(RuntimeError), wait=wait_exponential(multiplier=2, min=2, max=60))
def _api_request(data):
    sequence_num = random.randint(0, 0xFFFFFFFF)
    params = {'id': sequence_num}
    sequence_num += 1

    if not isinstance(data, list): data = [data]

    json_resp = json.loads(requests.post(f'https://g.api.mega.co.nz/cs', params=params, data=json.dumps(data), timeout=160).text)

    try:
        if isinstance(json_resp, list): int_resp = json_resp[0] if isinstance(json_resp[0], int) else None
        elif isinstance(json_resp, int): int_resp = json_resp
    except IndexError:
        int_resp = None

    if int_resp is not None:
        if int_resp == 0: return int_resp
        if int_resp == -3: raise RuntimeError('int_resp==-3')

        raise Exception(int_resp)

    return json_resp[0]


def base64_url_decode(data):
    data += '=='[(2 - len(data) * 3) % 4:]

    for search, replace in (('-', '+'), ('_', '/'), (',', '')):
        data = data.replace(search, replace)

    return base64.b64decode(data)


def str_to_a32(b):
    if isinstance(b, str): b = makebyte(b)
    if len(b) % 4: b += b'\0' * (4 - len(b) % 4)

    return struct.unpack('>%dI' % (len(b) / 4), b)


def mega_download_file(file_handle, file_key, dest_path=None, dest_filename=None, file=None):
    if file is None:
        file_key = str_to_a32(base64_url_decode(file_key))
        file_data = _api_request({'a': 'g', 'g': 1, 'p': file_handle})

        k = (file_key[0] ^ file_key[4], file_key[1] ^ file_key[5], file_key[2] ^ file_key[6], file_key[3] ^ file_key[7])
        iv = file_key[4:6] + (0, 0)
        meta_mac = file_key[6:8]
    else:
        file_data = _api_request({'a': 'g', 'g': 1, 'n': file['h']})
        k = file['k']
        iv = file['iv']
        meta_mac = file['meta_mac']

    if 'g' not in file_data: raise Exception(translations["file_not_access"])

    file_size = file_data['s']

    attribs = base64_url_decode(file_data['at'])
    attribs = decrypt_attr(attribs, k)

    file_name = dest_filename if dest_filename is not None else attribs['n']

    input_file = requests.get(file_data['g'], stream=True).raw

    if dest_path is None: dest_path = ''
    else: dest_path += '/'

    temp_output_file = tempfile.NamedTemporaryFile(mode='w+b', prefix='megapy_', delete=False)

    k_str = a32_to_str(k)

    counter = Counter.new(128, initial_value=((iv[0] << 32) + iv[1]) << 64)
    aes = AES.new(k_str, AES.MODE_CTR, counter=counter)

    mac_str = b'\0' * 16
    mac_encryptor = AES.new(k_str, AES.MODE_CBC, mac_str)

    iv_str = a32_to_str([iv[0], iv[1], iv[0], iv[1]])

    pbar = tqdm.tqdm(total=file_size)

    for _, chunk_size in get_chunks(file_size):
        chunk = input_file.read(chunk_size)
        chunk = aes.decrypt(chunk)
        temp_output_file.write(chunk)

        pbar.update(len(chunk))

        encryptor = AES.new(k_str, AES.MODE_CBC, iv_str)

        for i in range(0, len(chunk)-16, 16):
            block = chunk[i:i + 16]
            encryptor.encrypt(block)

        if file_size > 16: i += 16
        else: i = 0

        block = chunk[i:i + 16]
        if len(block) % 16: block += b'\0' * (16 - (len(block) % 16))

        mac_str = mac_encryptor.encrypt(encryptor.encrypt(block))

    file_mac = str_to_a32(mac_str)
    temp_output_file.close()

    if (file_mac[0] ^ file_mac[1], file_mac[2] ^ file_mac[3]) != meta_mac: raise ValueError(translations["mac_not_match"])

    file_path = os.path.join(dest_path, file_name)
    if os.path.exists(file_path): os.remove(file_path)

    shutil.move(temp_output_file.name, file_path)


def mega_download_url(url, dest_path=None, dest_filename=None):
    if '/file/' in url:
        url = url.replace(' ', '')
        file_id = re.findall(r'\W\w\w\w\w\w\w\w\w\W', url)[0][1:-1]

        path = f'{file_id}!{url[re.search(file_id, url).end() + 1:]}'.split('!')
    elif '!' in url: path = re.findall(r'/#!(.*)', url)[0].split('!')
    else: raise Exception(translations["missing_url"])

    return mega_download_file(file_handle=path[0], file_key=path[1], dest_path=dest_path, dest_filename=dest_filename)
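
A hypothetical call to the deleted MEGA helper (the share link is a placeholder, not a real file; mega_download_url is defined in the deleted source above):

from main.tools.meganz import mega_download_url

# Splits the share link into file handle and key, fetches the encrypted
# stream from MEGA's API, decrypts it chunk by chunk with AES-CTR, and
# verifies the meta-MAC before moving the temp file into dest_path.
mega_download_url("https://mega.nz/file/<HANDLE>#<KEY>", dest_path="downloads")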
main/tools/pixeldrain.py
DELETED
@@ -1,19 +0,0 @@
import os
import requests


def pixeldrain(url, output_dir):
    try:
        file_id = url.split("pixeldrain.com/u/")[1]
        response = requests.get(f"https://pixeldrain.com/api/file/{file_id}")

        if response.status_code == 200:
            file_name = (response.headers.get("Content-Disposition").split("filename=")[-1].strip('";'))
            file_path = os.path.join(output_dir, file_name)

            with open(file_path, "wb") as newfile:
                newfile.write(response.content)
            return file_path
        else: return None
    except Exception as e:
        raise RuntimeError(e)
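
And a hypothetical call to the deleted pixeldrain helper (the file ID and directory are placeholders; the function comes from the deleted source above):

from main.tools.pixeldrain import pixeldrain

# Extracts the ID after "pixeldrain.com/u/", downloads via the public file
# API, writes the body to output_dir, and returns the path (None on non-200).
path = pixeldrain("https://pixeldrain.com/u/<FILE_ID>", output_dir="downloads")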