Spaces:
Running
Running
Commit
·
bfa885e
0
Parent(s):
first commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +35 -0
- .gitignore +19 -0
- Dockerfile +21 -0
- README.md +11 -0
- examples/sample_filter/bad_case_find.py +84 -0
- examples/sample_filter/correction.py +70 -0
- examples/sample_filter/find_label_error_wav.py +77 -0
- examples/sample_filter/test2.py +78 -0
- examples/sample_filter/wav_find_by_task_excel.py +92 -0
- examples/vm_sound_classification/requirements.txt +10 -0
- examples/vm_sound_classification/run.sh +197 -0
- examples/vm_sound_classification/run_batch.sh +268 -0
- examples/vm_sound_classification/step_1_prepare_data.py +194 -0
- examples/vm_sound_classification/step_2_make_vocabulary.py +51 -0
- examples/vm_sound_classification/step_3_train_model.py +367 -0
- examples/vm_sound_classification/step_4_evaluation_model.py +128 -0
- examples/vm_sound_classification/step_5_export_models.py +106 -0
- examples/vm_sound_classification/step_6_infer.py +91 -0
- examples/vm_sound_classification/step_7_test_model.py +93 -0
- examples/vm_sound_classification/stop.sh +3 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-2-ch16.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-2-ch32.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-2-ch4.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-2-ch8.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-3-ch16.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-3-ch32.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-3-ch4.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-3-ch8.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-4-ch16.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-4-ch32.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-4-ch4.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-4-ch8.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-8-ch16.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-8-ch32.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-8-ch4.yaml +45 -0
- examples/vm_sound_classification/yaml/conv2d-classifier-8-ch8.yaml +45 -0
- examples/vm_sound_classification8/requirements.txt +9 -0
- examples/vm_sound_classification8/run.sh +157 -0
- examples/vm_sound_classification8/step_1_prepare_data.py +156 -0
- examples/vm_sound_classification8/step_2_make_vocabulary.py +69 -0
- examples/vm_sound_classification8/step_3_train_global_model.py +328 -0
- examples/vm_sound_classification8/step_4_train_country_model.py +349 -0
- examples/vm_sound_classification8/step_5_train_union.py +499 -0
- examples/vm_sound_classification8/stop.sh +3 -0
- install.sh +64 -0
- main.py +206 -0
- project_settings.py +19 -0
- requirements.txt +13 -0
- script/install_nvidia_driver.sh +184 -0
- script/install_python.sh +129 -0
.gitattributes
ADDED
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
.git/
|
3 |
+
.idea/
|
4 |
+
|
5 |
+
**/file_dir
|
6 |
+
**/flagged/
|
7 |
+
**/log/
|
8 |
+
**/logs/
|
9 |
+
**/__pycache__/
|
10 |
+
|
11 |
+
/data/
|
12 |
+
/docs/
|
13 |
+
/dotenv/
|
14 |
+
/examples/**/*.wav
|
15 |
+
/trained_models/
|
16 |
+
/temp/
|
17 |
+
|
18 |
+
#**/*.wav
|
19 |
+
**/*.xlsx
|
Dockerfile
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
FROM python:3.8
|
2 |
+
|
3 |
+
WORKDIR /code
|
4 |
+
|
5 |
+
COPY . /code
|
6 |
+
|
7 |
+
RUN pip install --upgrade pip
|
8 |
+
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
9 |
+
|
10 |
+
RUN useradd -m -u 1000 user
|
11 |
+
|
12 |
+
USER user
|
13 |
+
|
14 |
+
ENV HOME=/home/user \
|
15 |
+
PATH=/home/user/.local/bin:$PATH
|
16 |
+
|
17 |
+
WORKDIR $HOME/app
|
18 |
+
|
19 |
+
COPY --chown=user . $HOME/app
|
20 |
+
|
21 |
+
CMD ["python3", "main.py"]
|
README.md
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
title: VM Sound Classification
|
3 |
+
emoji: 🐢
|
4 |
+
colorFrom: purple
|
5 |
+
colorTo: blue
|
6 |
+
sdk: docker
|
7 |
+
pinned: false
|
8 |
+
license: apache-2.0
|
9 |
+
---
|
10 |
+
|
11 |
+
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
examples/sample_filter/bad_case_find.py
ADDED
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from pathlib import Path
|
5 |
+
import shutil
|
6 |
+
|
7 |
+
from gradio_client import Client, handle_file
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
|
11 |
+
def get_args():
|
12 |
+
parser = argparse.ArgumentParser()
|
13 |
+
parser.add_argument(
|
14 |
+
"--data_dir",
|
15 |
+
# default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\data",
|
16 |
+
# default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\us-3",
|
17 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\transfer",
|
18 |
+
# default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\id",
|
19 |
+
type=str
|
20 |
+
)
|
21 |
+
parser.add_argument(
|
22 |
+
"--keep_dir",
|
23 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\keep",
|
24 |
+
type=str
|
25 |
+
)
|
26 |
+
parser.add_argument(
|
27 |
+
"--trash_dir",
|
28 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\trash",
|
29 |
+
type=str
|
30 |
+
)
|
31 |
+
args = parser.parse_args()
|
32 |
+
return args
|
33 |
+
|
34 |
+
|
35 |
+
def main():
|
36 |
+
args = get_args()
|
37 |
+
|
38 |
+
data_dir = Path(args.data_dir)
|
39 |
+
keep_dir = Path(args.keep_dir)
|
40 |
+
keep_dir.mkdir(parents=True, exist_ok=True)
|
41 |
+
# trash_dir = Path(args.trash_dir)
|
42 |
+
# trash_dir.mkdir(parents=True, exist_ok=True)
|
43 |
+
|
44 |
+
client = Client("http://127.0.0.1:7864/")
|
45 |
+
|
46 |
+
for idx, filename in tqdm(enumerate(data_dir.glob("**/*.wav"))):
|
47 |
+
# if idx < 400:
|
48 |
+
# continue
|
49 |
+
filename = filename.as_posix()
|
50 |
+
|
51 |
+
label1, prob1 = client.predict(
|
52 |
+
audio=handle_file(filename),
|
53 |
+
# model_name="vm_sound_classification8-ch32",
|
54 |
+
model_name="voicemail-en-ph-2-ch4",
|
55 |
+
ground_true="Hello!!",
|
56 |
+
api_name="/click_button"
|
57 |
+
)
|
58 |
+
prob1 = float(prob1)
|
59 |
+
|
60 |
+
label2, prob2 = client.predict(
|
61 |
+
audio=handle_file(filename),
|
62 |
+
# model_name="vm_sound_classification8-ch32",
|
63 |
+
model_name="sound-8-ch32",
|
64 |
+
ground_true="Hello!!",
|
65 |
+
api_name="/click_button"
|
66 |
+
)
|
67 |
+
prob2 = float(prob2)
|
68 |
+
|
69 |
+
if label1 == "voicemail" and label2 in ("voicemail", "bell") and prob1 > 0.6:
|
70 |
+
pass
|
71 |
+
elif label1 == "non_voicemail" and label2 not in ("voicemail", "bell") and prob1 > 0.6:
|
72 |
+
pass
|
73 |
+
else:
|
74 |
+
print(f"label1: {label1}, prob1: {prob1}, label2: {label2}, prob2: {prob2}")
|
75 |
+
shutil.move(
|
76 |
+
filename,
|
77 |
+
keep_dir.as_posix(),
|
78 |
+
)
|
79 |
+
# exit(0)
|
80 |
+
return
|
81 |
+
|
82 |
+
|
83 |
+
if __name__ == '__main__':
|
84 |
+
main()
|
examples/sample_filter/correction.py
ADDED
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from pathlib import Path
|
5 |
+
import shutil
|
6 |
+
|
7 |
+
from gradio_client import Client, handle_file
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from project_settings import project_path
|
11 |
+
|
12 |
+
|
13 |
+
def get_args():
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument(
|
16 |
+
"--data_dir",
|
17 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\wav_finished\en-PH\wav_finished",
|
18 |
+
type=str
|
19 |
+
)
|
20 |
+
parser.add_argument(
|
21 |
+
"--correction_dir",
|
22 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\correction",
|
23 |
+
type=str
|
24 |
+
)
|
25 |
+
args = parser.parse_args()
|
26 |
+
return args
|
27 |
+
|
28 |
+
|
29 |
+
def main():
|
30 |
+
args = get_args()
|
31 |
+
|
32 |
+
data_dir = Path(args.data_dir)
|
33 |
+
correction_dir = Path(args.correction_dir)
|
34 |
+
correction_dir.mkdir(parents=True, exist_ok=True)
|
35 |
+
|
36 |
+
client = Client("http://127.0.0.1:7864/")
|
37 |
+
|
38 |
+
for idx, filename in tqdm(enumerate(data_dir.glob("**/*.wav"))):
|
39 |
+
# if idx < 200:
|
40 |
+
# continue
|
41 |
+
ground_truth = filename.parts[-2]
|
42 |
+
filename = filename.as_posix()
|
43 |
+
|
44 |
+
label, prob = client.predict(
|
45 |
+
audio=handle_file(filename),
|
46 |
+
model_name="voicemail-en-ph-2-ch32",
|
47 |
+
ground_true="Hello!!",
|
48 |
+
api_name="/click_button"
|
49 |
+
)
|
50 |
+
prob = float(prob)
|
51 |
+
|
52 |
+
if label == "voicemail" and ground_truth in ("voicemail", "bell"):
|
53 |
+
pass
|
54 |
+
elif label == "non_voicemail" and ground_truth not in ("voicemail", "bell"):
|
55 |
+
pass
|
56 |
+
else:
|
57 |
+
print(f"ground_truth: {ground_truth}, label: {label}, prob: {prob}")
|
58 |
+
|
59 |
+
tgt_dir = correction_dir / ground_truth
|
60 |
+
tgt_dir.mkdir(parents=True, exist_ok=True)
|
61 |
+
shutil.move(
|
62 |
+
filename,
|
63 |
+
tgt_dir.as_posix(),
|
64 |
+
)
|
65 |
+
|
66 |
+
return
|
67 |
+
|
68 |
+
|
69 |
+
if __name__ == '__main__':
|
70 |
+
main()
|
examples/sample_filter/find_label_error_wav.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from pathlib import Path
|
5 |
+
import shutil
|
6 |
+
|
7 |
+
from gradio_client import Client, handle_file
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from project_settings import project_path
|
11 |
+
|
12 |
+
|
13 |
+
def get_args():
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument(
|
16 |
+
"--data_dir",
|
17 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\wav_finished\en-US\wav_finished",
|
18 |
+
type=str
|
19 |
+
)
|
20 |
+
parser.add_argument(
|
21 |
+
"--keep_dir",
|
22 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\wav_finished\en-US\keep",
|
23 |
+
type=str
|
24 |
+
)
|
25 |
+
parser.add_argument(
|
26 |
+
"--trash_dir",
|
27 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\wav_finished\en-US\trash",
|
28 |
+
type=str
|
29 |
+
)
|
30 |
+
args = parser.parse_args()
|
31 |
+
return args
|
32 |
+
|
33 |
+
|
34 |
+
def main():
|
35 |
+
args = get_args()
|
36 |
+
|
37 |
+
data_dir = Path(args.data_dir)
|
38 |
+
keep_dir = Path(args.keep_dir)
|
39 |
+
keep_dir.mkdir(parents=True, exist_ok=True)
|
40 |
+
trash_dir = Path(args.trash_dir)
|
41 |
+
# trash_dir.mkdir(parents=True, exist_ok=True)
|
42 |
+
|
43 |
+
client = Client("http://127.0.0.1:7864/")
|
44 |
+
|
45 |
+
for idx, filename in tqdm(enumerate(data_dir.glob("**/*.wav"))):
|
46 |
+
# if idx < 200:
|
47 |
+
# continue
|
48 |
+
ground_truth = filename.parts[-2]
|
49 |
+
filename = filename.as_posix()
|
50 |
+
|
51 |
+
label1, prob1 = client.predict(
|
52 |
+
audio=handle_file(filename),
|
53 |
+
# model_name="vm_sound_classification8-ch32",
|
54 |
+
model_name="voicemail-en-us-2-ch32",
|
55 |
+
ground_true="Hello!!",
|
56 |
+
api_name="/click_button"
|
57 |
+
)
|
58 |
+
prob1 = float(prob1)
|
59 |
+
print(f"label: {label1}, prob: {prob1}, ground_truth: {ground_truth}")
|
60 |
+
|
61 |
+
if label1 == "voicemail" and ground_truth in ("bell", "voicemail") and prob1 > 0.65:
|
62 |
+
pass
|
63 |
+
elif label1 == "non_voicemail" and ground_truth not in ("bell", "voicemail") and prob1 > 0.65:
|
64 |
+
pass
|
65 |
+
else:
|
66 |
+
tgt = keep_dir / ground_truth
|
67 |
+
tgt.mkdir(parents=True, exist_ok=True)
|
68 |
+
shutil.move(
|
69 |
+
filename,
|
70 |
+
tgt.as_posix(),
|
71 |
+
)
|
72 |
+
|
73 |
+
return
|
74 |
+
|
75 |
+
|
76 |
+
if __name__ == '__main__':
|
77 |
+
main()
|
examples/sample_filter/test2.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from pathlib import Path
|
5 |
+
import shutil
|
6 |
+
|
7 |
+
from gradio_client import Client, handle_file
|
8 |
+
from tqdm import tqdm
|
9 |
+
|
10 |
+
from project_settings import project_path
|
11 |
+
|
12 |
+
|
13 |
+
def get_args():
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument(
|
16 |
+
"--data_dir",
|
17 |
+
# default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\data-1",
|
18 |
+
# default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-2\temp\VoiceAppVoicemailDetection-1",
|
19 |
+
# default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-3\temp\VoiceAppVoicemailDetection-1",
|
20 |
+
# default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-4\temp\VoiceAppVoicemailDetection-1",
|
21 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\transfer",
|
22 |
+
type=str
|
23 |
+
)
|
24 |
+
parser.add_argument(
|
25 |
+
"--keep_dir",
|
26 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\keep-3",
|
27 |
+
type=str
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"--trash_dir",
|
31 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\trash",
|
32 |
+
type=str
|
33 |
+
)
|
34 |
+
args = parser.parse_args()
|
35 |
+
return args
|
36 |
+
|
37 |
+
|
38 |
+
def main():
|
39 |
+
args = get_args()
|
40 |
+
|
41 |
+
data_dir = Path(args.data_dir)
|
42 |
+
keep_dir = Path(args.keep_dir)
|
43 |
+
keep_dir.mkdir(parents=True, exist_ok=True)
|
44 |
+
trash_dir = Path(args.trash_dir)
|
45 |
+
trash_dir.mkdir(parents=True, exist_ok=True)
|
46 |
+
|
47 |
+
client = Client("http://127.0.0.1:7864/")
|
48 |
+
|
49 |
+
for idx, filename in tqdm(enumerate(data_dir.glob("*.wav"))):
|
50 |
+
if idx < 200:
|
51 |
+
continue
|
52 |
+
filename = filename.as_posix()
|
53 |
+
|
54 |
+
label1, prob1 = client.predict(
|
55 |
+
audio=handle_file(filename),
|
56 |
+
# model_name="vm_sound_classification8-ch32",
|
57 |
+
model_name="voicemail-ms-my-2-ch32",
|
58 |
+
ground_true="Hello!!",
|
59 |
+
api_name="/click_button"
|
60 |
+
)
|
61 |
+
prob1 = float(prob1)
|
62 |
+
print(f"label: {label1}, prob: {prob1}")
|
63 |
+
|
64 |
+
if label1 == "voicemail" and prob1 < 0.95:
|
65 |
+
shutil.move(
|
66 |
+
filename,
|
67 |
+
keep_dir.as_posix(),
|
68 |
+
)
|
69 |
+
elif label1 != "voicemail" and prob1 < 0.85:
|
70 |
+
shutil.move(
|
71 |
+
filename,
|
72 |
+
keep_dir.as_posix(),
|
73 |
+
)
|
74 |
+
return
|
75 |
+
|
76 |
+
|
77 |
+
if __name__ == '__main__':
|
78 |
+
main()
|
examples/sample_filter/wav_find_by_task_excel.py
ADDED
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import shutil
|
7 |
+
|
8 |
+
import pandas as pd
|
9 |
+
from gradio_client import Client, handle_file
|
10 |
+
from tqdm import tqdm
|
11 |
+
|
12 |
+
from project_settings import project_path
|
13 |
+
|
14 |
+
|
15 |
+
task_file_str = """
|
16 |
+
task_DcTask_1_PH_LIVE_20250328_20250328-1.xlsx
|
17 |
+
task_DcTask_1_PH_LIVE_20250329_20250329-1.xlsx
|
18 |
+
task_DcTask_1_PH_LIVE_20250331_20250331-1.xlsx
|
19 |
+
task_DcTask_3_PH_LIVE_20250328_20250328-1.xlsx
|
20 |
+
task_DcTask_3_PH_LIVE_20250331_20250331-1.xlsx
|
21 |
+
task_DcTask_9_PH_LIVE_20250329_20250329-1.xlsx
|
22 |
+
task_DcTask_9_PH_LIVE_20250331_20250331-1.xlsx
|
23 |
+
"""
|
24 |
+
|
25 |
+
|
26 |
+
def get_args():
|
27 |
+
parser = argparse.ArgumentParser()
|
28 |
+
parser.add_argument(
|
29 |
+
"--task_file_str",
|
30 |
+
default=task_file_str,
|
31 |
+
type=str
|
32 |
+
)
|
33 |
+
parser.add_argument(
|
34 |
+
"--wav_dir",
|
35 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\phl",
|
36 |
+
# default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-2\temp\VoiceAppVoicemailDetection-1",
|
37 |
+
# default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-3\temp\VoiceAppVoicemailDetection-1",
|
38 |
+
# default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-4\temp\VoiceAppVoicemailDetection-1",
|
39 |
+
type=str
|
40 |
+
)
|
41 |
+
parser.add_argument(
|
42 |
+
"--output_dir",
|
43 |
+
default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\transfer",
|
44 |
+
type=str
|
45 |
+
)
|
46 |
+
args = parser.parse_args()
|
47 |
+
return args
|
48 |
+
|
49 |
+
|
50 |
+
def main():
|
51 |
+
args = get_args()
|
52 |
+
wav_dir = Path(args.wav_dir)
|
53 |
+
output_dir = Path(args.output_dir)
|
54 |
+
output_dir.mkdir(parents=True, exist_ok=True)
|
55 |
+
|
56 |
+
task_file_list = task_file_str.split("\n")
|
57 |
+
task_file_list = [task_file for task_file in task_file_list if len(task_file.strip()) != 0]
|
58 |
+
print(f"task_file_list: {task_file_list}")
|
59 |
+
|
60 |
+
for task_file in task_file_list:
|
61 |
+
df = pd.read_excel(task_file)
|
62 |
+
|
63 |
+
transfer_set = set()
|
64 |
+
for i, row in df.iterrows():
|
65 |
+
call_id = row["通话ID"]
|
66 |
+
intent_str = row["意向标签"]
|
67 |
+
if intent_str == "Connection - Transferred to agent":
|
68 |
+
transfer_set.add(call_id)
|
69 |
+
if intent_str == "Connection - No human voice detected":
|
70 |
+
transfer_set.add(call_id)
|
71 |
+
|
72 |
+
print(f"transfer count: {len(transfer_set)}")
|
73 |
+
|
74 |
+
for idx, filename in tqdm(enumerate(wav_dir.glob("**/*.wav"))):
|
75 |
+
|
76 |
+
basename = filename.stem
|
77 |
+
call_id, _, _, _ = basename.split("_")
|
78 |
+
|
79 |
+
if call_id not in transfer_set:
|
80 |
+
continue
|
81 |
+
|
82 |
+
print(filename.as_posix())
|
83 |
+
shutil.move(
|
84 |
+
filename.as_posix(),
|
85 |
+
output_dir.as_posix()
|
86 |
+
)
|
87 |
+
|
88 |
+
return
|
89 |
+
|
90 |
+
|
91 |
+
if __name__ == '__main__':
|
92 |
+
main()
|
examples/vm_sound_classification/requirements.txt
ADDED
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.13.1
|
2 |
+
torchaudio==0.13.1
|
3 |
+
fsspec==2022.1.0
|
4 |
+
librosa==0.9.2
|
5 |
+
pandas==1.1.5
|
6 |
+
openpyxl==3.0.9
|
7 |
+
xlrd==1.2.0
|
8 |
+
tqdm==4.64.1
|
9 |
+
overrides==1.9.0
|
10 |
+
pyyaml==6.0.1
|
examples/vm_sound_classification/run.sh
ADDED
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
: <<'END'
|
4 |
+
|
5 |
+
sh run.sh --stage 0 --stop_stage 1 --system_version windows --file_folder_name file_dir --final_model_name sound-4-ch32 \
|
6 |
+
--filename_patterns "E:/Users/tianx/HuggingDatasets/vm_sound_classification/data/wav_finished/wav_finished/en-US/wav_finished/*/*.wav \
|
7 |
+
E:/Users/tianx/HuggingDatasets/vm_sound_classification/data/wav_finished/id-ID/wav_finished/*/*.wav" \
|
8 |
+
--label_plan 4
|
9 |
+
|
10 |
+
sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name sound-2-ch32 \
|
11 |
+
--filename_patterns "E:/Users/tianx/HuggingDatasets/vm_sound_classification/data/wav_finished/wav_finished/en-US/wav_finished/*/*.wav \
|
12 |
+
E:/Users/tianx/HuggingDatasets/vm_sound_classification/data/wav_finished/id-ID/wav_finished/*/*.wav" \
|
13 |
+
--label_plan 4
|
14 |
+
|
15 |
+
sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-3-ch32 \
|
16 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
17 |
+
--label_plan 3 \
|
18 |
+
--config_file "yaml/conv2d-classifier-3-ch4.yaml"
|
19 |
+
|
20 |
+
sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ms-my-2-ch32 \
|
21 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ms-MY/wav_finished/*/*.wav" \
|
22 |
+
--label_plan 2-voicemail \
|
23 |
+
--config_file "yaml/conv2d-classifier-2-ch32.yaml"
|
24 |
+
|
25 |
+
END
|
26 |
+
|
27 |
+
|
28 |
+
# params
|
29 |
+
system_version="windows";
|
30 |
+
verbose=true;
|
31 |
+
stage=0 # start from 0 if you need to start from data preparation
|
32 |
+
stop_stage=9
|
33 |
+
|
34 |
+
work_dir="$(pwd)"
|
35 |
+
file_folder_name=file_folder_name
|
36 |
+
final_model_name=final_model_name
|
37 |
+
filename_patterns="/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
38 |
+
label_plan=4
|
39 |
+
config_file="yaml/conv2d-classifier-2-ch4.yaml"
|
40 |
+
pretrained_model=null
|
41 |
+
nohup_name=nohup.out
|
42 |
+
|
43 |
+
country=en-US
|
44 |
+
|
45 |
+
# model params
|
46 |
+
batch_size=64
|
47 |
+
max_epochs=200
|
48 |
+
save_top_k=10
|
49 |
+
patience=5
|
50 |
+
|
51 |
+
|
52 |
+
# parse options
|
53 |
+
while true; do
|
54 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
55 |
+
case "$1" in
|
56 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
57 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
58 |
+
old_value="(eval echo \\$$name)";
|
59 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
60 |
+
was_bool=true;
|
61 |
+
else
|
62 |
+
was_bool=false;
|
63 |
+
fi
|
64 |
+
|
65 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
66 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
67 |
+
eval "${name}=\"$2\"";
|
68 |
+
|
69 |
+
# Check that Boolean-valued arguments are really Boolean.
|
70 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
71 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
72 |
+
exit 1;
|
73 |
+
fi
|
74 |
+
shift 2;
|
75 |
+
;;
|
76 |
+
|
77 |
+
*) break;
|
78 |
+
esac
|
79 |
+
done
|
80 |
+
|
81 |
+
file_dir="${work_dir}/${file_folder_name}"
|
82 |
+
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
|
83 |
+
|
84 |
+
dataset="${file_dir}/dataset.xlsx"
|
85 |
+
train_dataset="${file_dir}/train.xlsx"
|
86 |
+
valid_dataset="${file_dir}/valid.xlsx"
|
87 |
+
evaluation_file="${file_dir}/evaluation.xlsx"
|
88 |
+
vocabulary_dir="${file_dir}/vocabulary"
|
89 |
+
|
90 |
+
$verbose && echo "system_version: ${system_version}"
|
91 |
+
$verbose && echo "file_folder_name: ${file_folder_name}"
|
92 |
+
|
93 |
+
if [ $system_version == "windows" ]; then
|
94 |
+
alias python3='D:/Users/tianx/PycharmProjects/virtualenv/vm_sound_classification/Scripts/python.exe'
|
95 |
+
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
|
96 |
+
#source /data/local/bin/vm_sound_classification/bin/activate
|
97 |
+
alias python3='/data/local/bin/vm_sound_classification/bin/python3'
|
98 |
+
fi
|
99 |
+
|
100 |
+
|
101 |
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
102 |
+
$verbose && echo "stage 0: prepare data"
|
103 |
+
cd "${work_dir}" || exit 1
|
104 |
+
python3 step_1_prepare_data.py \
|
105 |
+
--file_dir "${file_dir}" \
|
106 |
+
--filename_patterns "${filename_patterns}" \
|
107 |
+
--train_dataset "${train_dataset}" \
|
108 |
+
--valid_dataset "${valid_dataset}" \
|
109 |
+
--label_plan "${label_plan}" \
|
110 |
+
|
111 |
+
fi
|
112 |
+
|
113 |
+
|
114 |
+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
115 |
+
$verbose && echo "stage 1: make vocabulary"
|
116 |
+
cd "${work_dir}" || exit 1
|
117 |
+
python3 step_2_make_vocabulary.py \
|
118 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
119 |
+
--train_dataset "${train_dataset}" \
|
120 |
+
--valid_dataset "${valid_dataset}" \
|
121 |
+
|
122 |
+
fi
|
123 |
+
|
124 |
+
|
125 |
+
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
126 |
+
$verbose && echo "stage 2: train model"
|
127 |
+
cd "${work_dir}" || exit 1
|
128 |
+
python3 step_3_train_model.py \
|
129 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
130 |
+
--train_dataset "${train_dataset}" \
|
131 |
+
--valid_dataset "${valid_dataset}" \
|
132 |
+
--serialization_dir "${file_dir}" \
|
133 |
+
--config_file "${config_file}" \
|
134 |
+
--pretrained_model "${pretrained_model}" \
|
135 |
+
|
136 |
+
fi
|
137 |
+
|
138 |
+
|
139 |
+
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
140 |
+
$verbose && echo "stage 3: test model"
|
141 |
+
cd "${work_dir}" || exit 1
|
142 |
+
python3 step_4_evaluation_model.py \
|
143 |
+
--dataset "${dataset}" \
|
144 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
145 |
+
--model_dir "${file_dir}/best" \
|
146 |
+
--output_file "${evaluation_file}" \
|
147 |
+
|
148 |
+
fi
|
149 |
+
|
150 |
+
|
151 |
+
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
152 |
+
$verbose && echo "stage 4: export model"
|
153 |
+
cd "${work_dir}" || exit 1
|
154 |
+
python3 step_5_export_models.py \
|
155 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
156 |
+
--model_dir "${file_dir}/best" \
|
157 |
+
--serialization_dir "${file_dir}" \
|
158 |
+
|
159 |
+
fi
|
160 |
+
|
161 |
+
|
162 |
+
if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
|
163 |
+
$verbose && echo "stage 5: collect files"
|
164 |
+
cd "${work_dir}" || exit 1
|
165 |
+
|
166 |
+
mkdir -p ${final_model_dir}
|
167 |
+
|
168 |
+
cp "${file_dir}/best"/* "${final_model_dir}"
|
169 |
+
cp -r "${file_dir}/vocabulary" "${final_model_dir}"
|
170 |
+
|
171 |
+
cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
|
172 |
+
|
173 |
+
cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
|
174 |
+
cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
|
175 |
+
cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
|
176 |
+
cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
|
177 |
+
|
178 |
+
cd "${final_model_dir}/.." || exit 1;
|
179 |
+
|
180 |
+
if [ -e "${final_model_name}.zip" ]; then
|
181 |
+
rm -rf "${final_model_name}_backup.zip"
|
182 |
+
mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
|
183 |
+
fi
|
184 |
+
|
185 |
+
zip -r "${final_model_name}.zip" "${final_model_name}"
|
186 |
+
rm -rf "${final_model_name}"
|
187 |
+
|
188 |
+
fi
|
189 |
+
|
190 |
+
|
191 |
+
if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
|
192 |
+
$verbose && echo "stage 6: clear file_dir"
|
193 |
+
cd "${work_dir}" || exit 1
|
194 |
+
|
195 |
+
rm -rf "${file_dir}";
|
196 |
+
|
197 |
+
fi
|
examples/vm_sound_classification/run_batch.sh
ADDED
@@ -0,0 +1,268 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
|
4 |
+
# sound ch4
|
5 |
+
|
6 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-2-ch4 \
|
7 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
8 |
+
#--label_plan 2 \
|
9 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml"
|
10 |
+
#
|
11 |
+
#
|
12 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-3-ch4 \
|
13 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
14 |
+
#--label_plan 3 \
|
15 |
+
#--config_file "yaml/conv2d-classifier-3-ch4.yaml"
|
16 |
+
#
|
17 |
+
#
|
18 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-4-ch4 \
|
19 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
20 |
+
#--label_plan 4 \
|
21 |
+
#--config_file "yaml/conv2d-classifier-4-ch4.yaml"
|
22 |
+
#
|
23 |
+
#
|
24 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-8-ch4 \
|
25 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
26 |
+
#--label_plan 8 \
|
27 |
+
#--config_file "yaml/conv2d-classifier-8-ch4.yaml"
|
28 |
+
|
29 |
+
|
30 |
+
# sound ch8
|
31 |
+
|
32 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-2-ch8 \
|
33 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
34 |
+
#--label_plan 2 \
|
35 |
+
#--config_file "yaml/conv2d-classifier-2-ch8.yaml"
|
36 |
+
#
|
37 |
+
#
|
38 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-3-ch8 \
|
39 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
40 |
+
#--label_plan 3 \
|
41 |
+
#--config_file "yaml/conv2d-classifier-3-ch8.yaml"
|
42 |
+
#
|
43 |
+
#
|
44 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-4-ch8 \
|
45 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
46 |
+
#--label_plan 4 \
|
47 |
+
#--config_file "yaml/conv2d-classifier-4-ch8.yaml"
|
48 |
+
#
|
49 |
+
#
|
50 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-8-ch8 \
|
51 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
52 |
+
#--label_plan 8 \
|
53 |
+
#--config_file "yaml/conv2d-classifier-8-ch8.yaml"
|
54 |
+
|
55 |
+
|
56 |
+
# sound ch16
|
57 |
+
|
58 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-2-ch16 \
|
59 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
60 |
+
#--label_plan 2 \
|
61 |
+
#--config_file "yaml/conv2d-classifier-2-ch16.yaml"
|
62 |
+
|
63 |
+
|
64 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-3-ch16 \
|
65 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
66 |
+
#--label_plan 3 \
|
67 |
+
#--config_file "yaml/conv2d-classifier-3-ch16.yaml"
|
68 |
+
#
|
69 |
+
#
|
70 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-4-ch16 \
|
71 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
72 |
+
#--label_plan 4 \
|
73 |
+
#--config_file "yaml/conv2d-classifier-4-ch16.yaml"
|
74 |
+
#
|
75 |
+
#
|
76 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-8-ch16 \
|
77 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
78 |
+
#--label_plan 8 \
|
79 |
+
#--config_file "yaml/conv2d-classifier-8-ch16.yaml"
|
80 |
+
|
81 |
+
|
82 |
+
# sound ch32
|
83 |
+
|
84 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-2-ch32 \
|
85 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
86 |
+
#--label_plan 2 \
|
87 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml"
|
88 |
+
#
|
89 |
+
#
|
90 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-3-ch32 \
|
91 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
92 |
+
#--label_plan 3 \
|
93 |
+
#--config_file "yaml/conv2d-classifier-3-ch32.yaml"
|
94 |
+
#
|
95 |
+
#
|
96 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-4-ch32 \
|
97 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
98 |
+
#--label_plan 4 \
|
99 |
+
#--config_file "yaml/conv2d-classifier-4-ch32.yaml"
|
100 |
+
|
101 |
+
|
102 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-8-ch32 \
|
103 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
104 |
+
#--label_plan 8 \
|
105 |
+
#--config_file "yaml/conv2d-classifier-8-ch32.yaml"
|
106 |
+
|
107 |
+
|
108 |
+
# pretrained voicemail
|
109 |
+
|
110 |
+
sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-2-ch4 \
|
111 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
112 |
+
--label_plan 2-voicemail \
|
113 |
+
--config_file "yaml/conv2d-classifier-2-ch4.yaml"
|
114 |
+
|
115 |
+
sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-2-ch32 \
|
116 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
|
117 |
+
--label_plan 2-voicemail \
|
118 |
+
--config_file "yaml/conv2d-classifier-2-ch32.yaml"
|
119 |
+
|
120 |
+
|
121 |
+
# voicemail ch4
|
122 |
+
|
123 |
+
sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-ph-2-ch4 \
|
124 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-PH/wav_finished/*/*.wav" \
|
125 |
+
--label_plan 2-voicemail \
|
126 |
+
--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
127 |
+
--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
128 |
+
|
129 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-sg-2-ch4 \
|
130 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-SG/wav_finished/*/*.wav" \
|
131 |
+
#--label_plan 2-voicemail \
|
132 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
133 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
134 |
+
#
|
135 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-us-2-ch4 \
|
136 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-US/wav_finished/*/*.wav" \
|
137 |
+
#--label_plan 2-voicemail \
|
138 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
139 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
140 |
+
#
|
141 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-es-mx-2-ch4 \
|
142 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/es-MX/wav_finished/*/*.wav" \
|
143 |
+
#--label_plan 2-voicemail \
|
144 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
145 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
146 |
+
#
|
147 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-es-pe-2-ch4 \
|
148 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/es-PE/wav_finished/*/*.wav" \
|
149 |
+
#--label_plan 2-voicemail \
|
150 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
151 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
152 |
+
#
|
153 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-id-id-2-ch4 \
|
154 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/id-ID/wav_finished/*/*.wav" \
|
155 |
+
#--label_plan 2-voicemail \
|
156 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
157 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
158 |
+
#
|
159 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ja-jp-2-ch4 \
|
160 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ja-JP/wav_finished/*/*.wav" \
|
161 |
+
#--label_plan 2-voicemail \
|
162 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
163 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
164 |
+
#
|
165 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ko-kr-2-ch4 \
|
166 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ko-KR/wav_finished/*/*.wav" \
|
167 |
+
#--label_plan 2-voicemail \
|
168 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
169 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
170 |
+
#
|
171 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ms-my-2-ch4 \
|
172 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ms-MY/wav_finished/*/*.wav" \
|
173 |
+
#--label_plan 2-voicemail \
|
174 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
175 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
176 |
+
#
|
177 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-pt-br-2-ch4 \
|
178 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/pt-BR/wav_finished/*/*.wav" \
|
179 |
+
#--label_plan 2-voicemail \
|
180 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
181 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
182 |
+
#
|
183 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-th-th-2-ch4 \
|
184 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/th-TH/wav_finished/*/*.wav" \
|
185 |
+
#--label_plan 2-voicemail \
|
186 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
187 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
188 |
+
#
|
189 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-zh-tw-2-ch4 \
|
190 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/zh-TW/wav_finished/*/*.wav" \
|
191 |
+
#--label_plan 2-voicemail \
|
192 |
+
#--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
|
193 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
|
194 |
+
|
195 |
+
|
196 |
+
# voicemail ch32
|
197 |
+
|
198 |
+
sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-ph-2-ch32 \
|
199 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-PH/wav_finished/*/*.wav" \
|
200 |
+
--label_plan 2-voicemail \
|
201 |
+
--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
202 |
+
--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
203 |
+
|
204 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-sg-2-ch32 \
|
205 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-SG/wav_finished/*/*.wav" \
|
206 |
+
#--label_plan 2-voicemail \
|
207 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
208 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
209 |
+
#
|
210 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-us-2-ch32 \
|
211 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-US/wav_finished/*/*.wav" \
|
212 |
+
#--label_plan 2-voicemail \
|
213 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
214 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
215 |
+
#
|
216 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-es-mx-2-ch32 \
|
217 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/es-MX/wav_finished/*/*.wav" \
|
218 |
+
#--label_plan 2-voicemail \
|
219 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
220 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
221 |
+
#
|
222 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-es-pe-2-ch32 \
|
223 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/es-PE/wav_finished/*/*.wav" \
|
224 |
+
#--label_plan 2-voicemail \
|
225 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
226 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
227 |
+
#
|
228 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-id-id-2-ch32 \
|
229 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/id-ID/wav_finished/*/*.wav" \
|
230 |
+
#--label_plan 2-voicemail \
|
231 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
232 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
233 |
+
#
|
234 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ja-jp-2-ch32 \
|
235 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ja-JP/wav_finished/*/*.wav" \
|
236 |
+
#--label_plan 2-voicemail \
|
237 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
238 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
239 |
+
#
|
240 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ko-kr-2-ch32 \
|
241 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ko-KR/wav_finished/*/*.wav" \
|
242 |
+
#--label_plan 2-voicemail \
|
243 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
244 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
245 |
+
#
|
246 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ms-my-2-ch32 \
|
247 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ms-MY/wav_finished/*/*.wav" \
|
248 |
+
#--label_plan 2-voicemail \
|
249 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
250 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
251 |
+
#
|
252 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-pt-br-2-ch32 \
|
253 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/pt-BR/wav_finished/*/*.wav" \
|
254 |
+
#--label_plan 2-voicemail \
|
255 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
256 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
257 |
+
#
|
258 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-th-th-2-ch32 \
|
259 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/th-TH/wav_finished/*/*.wav" \
|
260 |
+
#--label_plan 2-voicemail \
|
261 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
262 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
263 |
+
#
|
264 |
+
#sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-zh-tw-2-ch32 \
|
265 |
+
#--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/zh-TW/wav_finished/*/*.wav" \
|
266 |
+
#--label_plan 2-voicemail \
|
267 |
+
#--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
|
268 |
+
#--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
|
examples/vm_sound_classification/step_1_prepare_data.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from glob import glob
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
import random
|
9 |
+
import sys
|
10 |
+
|
11 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
12 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
13 |
+
|
14 |
+
import pandas as pd
|
15 |
+
from scipy.io import wavfile
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
|
19 |
+
def get_args():
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--file_dir", default="./", type=str)
|
22 |
+
parser.add_argument("--filename_patterns", type=str)
|
23 |
+
|
24 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
25 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
26 |
+
|
27 |
+
parser.add_argument("--label_plan", default="4", type=str)
|
28 |
+
|
29 |
+
args = parser.parse_args()
|
30 |
+
return args
|
31 |
+
|
32 |
+
|
33 |
+
def get_dataset(args):
|
34 |
+
filename_patterns = args.filename_patterns
|
35 |
+
filename_patterns = filename_patterns.split(" ")
|
36 |
+
print(filename_patterns)
|
37 |
+
|
38 |
+
file_dir = Path(args.file_dir)
|
39 |
+
file_dir.mkdir(exist_ok=True)
|
40 |
+
|
41 |
+
if args.label_plan == "2-voicemail":
|
42 |
+
label_map = {
|
43 |
+
"bell": "voicemail",
|
44 |
+
"white_noise": "non_voicemail",
|
45 |
+
"low_white_noise": "non_voicemail",
|
46 |
+
"high_white_noise": "non_voicemail",
|
47 |
+
# "music": "non_voicemail",
|
48 |
+
"mute": "non_voicemail",
|
49 |
+
"noise": "non_voicemail",
|
50 |
+
"noise_mute": "non_voicemail",
|
51 |
+
"voice": "non_voicemail",
|
52 |
+
"voicemail": "voicemail",
|
53 |
+
}
|
54 |
+
elif args.label_plan == "2":
|
55 |
+
label_map = {
|
56 |
+
"bell": "non_voice",
|
57 |
+
"white_noise": "non_voice",
|
58 |
+
"low_white_noise": "non_voice",
|
59 |
+
"high_white_noise": "non_voice",
|
60 |
+
"music": "non_voice",
|
61 |
+
"mute": "non_voice",
|
62 |
+
"noise": "non_voice",
|
63 |
+
"noise_mute": "non_voice",
|
64 |
+
"voice": "voice",
|
65 |
+
"voicemail": "voice",
|
66 |
+
}
|
67 |
+
elif args.label_plan == "3":
|
68 |
+
label_map = {
|
69 |
+
"bell": "voicemail",
|
70 |
+
"white_noise": "mute",
|
71 |
+
"low_white_noise": "mute",
|
72 |
+
"high_white_noise": "mute",
|
73 |
+
# "music": "music",
|
74 |
+
"mute": "mute",
|
75 |
+
"noise": "voice_or_noise",
|
76 |
+
"noise_mute": "voice_or_noise",
|
77 |
+
"voice": "voice_or_noise",
|
78 |
+
"voicemail": "voicemail",
|
79 |
+
}
|
80 |
+
elif args.label_plan == "4":
|
81 |
+
label_map = {
|
82 |
+
"bell": "voicemail",
|
83 |
+
"white_noise": "mute",
|
84 |
+
"low_white_noise": "mute",
|
85 |
+
"high_white_noise": "mute",
|
86 |
+
# "music": "music",
|
87 |
+
"mute": "mute",
|
88 |
+
"noise": "noise",
|
89 |
+
"noise_mute": "noise",
|
90 |
+
"voice": "voice",
|
91 |
+
"voicemail": "voicemail",
|
92 |
+
}
|
93 |
+
elif args.label_plan == "8":
|
94 |
+
label_map = {
|
95 |
+
"bell": "bell",
|
96 |
+
"white_noise": "white_noise",
|
97 |
+
"low_white_noise": "white_noise",
|
98 |
+
"high_white_noise": "white_noise",
|
99 |
+
"music": "music",
|
100 |
+
"mute": "mute",
|
101 |
+
"noise": "noise",
|
102 |
+
"noise_mute": "noise_mute",
|
103 |
+
"voice": "voice",
|
104 |
+
"voicemail": "voicemail",
|
105 |
+
}
|
106 |
+
else:
|
107 |
+
raise AssertionError
|
108 |
+
|
109 |
+
result = list()
|
110 |
+
for filename_pattern in filename_patterns:
|
111 |
+
filename_list = glob(filename_pattern)
|
112 |
+
for filename in tqdm(filename_list):
|
113 |
+
filename = Path(filename)
|
114 |
+
sample_rate, signal = wavfile.read(filename.as_posix())
|
115 |
+
if len(signal) < sample_rate * 2:
|
116 |
+
continue
|
117 |
+
|
118 |
+
folder = filename.parts[-2]
|
119 |
+
country = filename.parts[-4]
|
120 |
+
|
121 |
+
if folder not in label_map.keys():
|
122 |
+
continue
|
123 |
+
|
124 |
+
labels = label_map[folder]
|
125 |
+
|
126 |
+
random1 = random.random()
|
127 |
+
random2 = random.random()
|
128 |
+
|
129 |
+
result.append({
|
130 |
+
"filename": filename,
|
131 |
+
"folder": folder,
|
132 |
+
"category": country,
|
133 |
+
"labels": labels,
|
134 |
+
"random1": random1,
|
135 |
+
"random2": random2,
|
136 |
+
"flag": "TRAIN" if random2 < 0.8 else "TEST",
|
137 |
+
})
|
138 |
+
|
139 |
+
df = pd.DataFrame(result)
|
140 |
+
pivot_table = pd.pivot_table(df, index=["labels"], values=["filename"], aggfunc="count")
|
141 |
+
print(pivot_table)
|
142 |
+
|
143 |
+
df = df.sort_values(by=["random1"], ascending=False)
|
144 |
+
df.to_excel(
|
145 |
+
file_dir / "dataset.xlsx",
|
146 |
+
index=False,
|
147 |
+
# encoding="utf_8_sig"
|
148 |
+
)
|
149 |
+
|
150 |
+
return
|
151 |
+
|
152 |
+
|
153 |
+
def split_dataset(args):
|
154 |
+
"""分割训练集, 测试集"""
|
155 |
+
file_dir = Path(args.file_dir)
|
156 |
+
file_dir.mkdir(exist_ok=True)
|
157 |
+
|
158 |
+
df = pd.read_excel(file_dir / "dataset.xlsx")
|
159 |
+
|
160 |
+
train = list()
|
161 |
+
test = list()
|
162 |
+
|
163 |
+
for i, row in df.iterrows():
|
164 |
+
flag = row["flag"]
|
165 |
+
if flag == "TRAIN":
|
166 |
+
train.append(row)
|
167 |
+
else:
|
168 |
+
test.append(row)
|
169 |
+
|
170 |
+
train = pd.DataFrame(train)
|
171 |
+
train.to_excel(
|
172 |
+
args.train_dataset,
|
173 |
+
index=False,
|
174 |
+
# encoding="utf_8_sig"
|
175 |
+
)
|
176 |
+
test = pd.DataFrame(test)
|
177 |
+
test.to_excel(
|
178 |
+
args.valid_dataset,
|
179 |
+
index=False,
|
180 |
+
# encoding="utf_8_sig"
|
181 |
+
)
|
182 |
+
|
183 |
+
return
|
184 |
+
|
185 |
+
|
186 |
+
def main():
|
187 |
+
args = get_args()
|
188 |
+
get_dataset(args)
|
189 |
+
split_dataset(args)
|
190 |
+
return
|
191 |
+
|
192 |
+
|
193 |
+
if __name__ == "__main__":
|
194 |
+
main()
|
examples/vm_sound_classification/step_2_make_vocabulary.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import sys
|
7 |
+
|
8 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
9 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
10 |
+
|
11 |
+
import pandas as pd
|
12 |
+
|
13 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
14 |
+
|
15 |
+
|
16 |
+
def get_args():
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
19 |
+
|
20 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
21 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
22 |
+
|
23 |
+
args = parser.parse_args()
|
24 |
+
return args
|
25 |
+
|
26 |
+
|
27 |
+
def main():
|
28 |
+
args = get_args()
|
29 |
+
|
30 |
+
train_dataset = pd.read_excel(args.train_dataset)
|
31 |
+
valid_dataset = pd.read_excel(args.valid_dataset)
|
32 |
+
|
33 |
+
vocabulary = Vocabulary()
|
34 |
+
|
35 |
+
# train
|
36 |
+
for i, row in train_dataset.iterrows():
|
37 |
+
label = row["labels"]
|
38 |
+
vocabulary.add_token_to_namespace(label, namespace="labels")
|
39 |
+
|
40 |
+
# valid
|
41 |
+
for i, row in valid_dataset.iterrows():
|
42 |
+
label = row["labels"]
|
43 |
+
vocabulary.add_token_to_namespace(label, namespace="labels")
|
44 |
+
|
45 |
+
vocabulary.save_to_files(args.vocabulary_dir)
|
46 |
+
|
47 |
+
return
|
48 |
+
|
49 |
+
|
50 |
+
if __name__ == "__main__":
|
51 |
+
main()
|
examples/vm_sound_classification/step_3_train_model.py
ADDED
@@ -0,0 +1,367 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from collections import defaultdict
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
from logging.handlers import TimedRotatingFileHandler
|
8 |
+
import os
|
9 |
+
import platform
|
10 |
+
from pathlib import Path
|
11 |
+
import random
|
12 |
+
import sys
|
13 |
+
import shutil
|
14 |
+
import tempfile
|
15 |
+
from typing import List
|
16 |
+
import zipfile
|
17 |
+
|
18 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
19 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
import torch
|
23 |
+
from torch.utils.data.dataloader import DataLoader
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
|
27 |
+
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
|
28 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
29 |
+
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
|
30 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
|
31 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.configuration_cnn_audio_classifier import CnnAudioClassifierConfig
|
32 |
+
|
33 |
+
|
34 |
+
def get_args():
|
35 |
+
parser = argparse.ArgumentParser()
|
36 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
37 |
+
|
38 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
39 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
40 |
+
|
41 |
+
parser.add_argument("--max_epochs", default=100, type=int)
|
42 |
+
|
43 |
+
parser.add_argument("--batch_size", default=64, type=int)
|
44 |
+
parser.add_argument("--learning_rate", default=1e-3, type=float)
|
45 |
+
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
46 |
+
parser.add_argument("--patience", default=5, type=int)
|
47 |
+
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
48 |
+
parser.add_argument("--seed", default=0, type=int)
|
49 |
+
|
50 |
+
parser.add_argument("--config_file", default="conv2d_classifier.yaml", type=str)
|
51 |
+
parser.add_argument(
|
52 |
+
"--pretrained_model",
|
53 |
+
# default=(project_path / "trained_models/voicemail-en-sg-2-ch4.zip").as_posix(),
|
54 |
+
default="null",
|
55 |
+
type=str
|
56 |
+
)
|
57 |
+
|
58 |
+
args = parser.parse_args()
|
59 |
+
return args
|
60 |
+
|
61 |
+
|
62 |
+
def logging_config(file_dir: str):
|
63 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
64 |
+
|
65 |
+
logging.basicConfig(format=fmt,
|
66 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
67 |
+
level=logging.DEBUG)
|
68 |
+
file_handler = TimedRotatingFileHandler(
|
69 |
+
filename=os.path.join(file_dir, "main.log"),
|
70 |
+
encoding="utf-8",
|
71 |
+
when="D",
|
72 |
+
interval=1,
|
73 |
+
backupCount=7
|
74 |
+
)
|
75 |
+
file_handler.setLevel(logging.INFO)
|
76 |
+
file_handler.setFormatter(logging.Formatter(fmt))
|
77 |
+
logger = logging.getLogger(__name__)
|
78 |
+
logger.addHandler(file_handler)
|
79 |
+
|
80 |
+
return logger
|
81 |
+
|
82 |
+
|
83 |
+
class CollateFunction(object):
|
84 |
+
def __init__(self):
|
85 |
+
pass
|
86 |
+
|
87 |
+
def __call__(self, batch: List[dict]):
|
88 |
+
array_list = list()
|
89 |
+
label_list = list()
|
90 |
+
for sample in batch:
|
91 |
+
array = sample["waveform"]
|
92 |
+
label = sample["label"]
|
93 |
+
|
94 |
+
l = len(array)
|
95 |
+
if l < 16000:
|
96 |
+
delta = int(16000 - l)
|
97 |
+
array = np.concatenate([array, np.zeros(shape=(delta,), dtype=np.float32)], axis=-1)
|
98 |
+
if l > 16000:
|
99 |
+
array = array[:16000]
|
100 |
+
|
101 |
+
array_list.append(array)
|
102 |
+
label_list.append(label)
|
103 |
+
|
104 |
+
array_list = torch.stack(array_list)
|
105 |
+
label_list = torch.stack(label_list)
|
106 |
+
return array_list, label_list
|
107 |
+
|
108 |
+
|
109 |
+
collate_fn = CollateFunction()
|
110 |
+
|
111 |
+
|
112 |
+
def main():
|
113 |
+
args = get_args()
|
114 |
+
|
115 |
+
serialization_dir = Path(args.serialization_dir)
|
116 |
+
serialization_dir.mkdir(parents=True, exist_ok=True)
|
117 |
+
|
118 |
+
logger = logging_config(serialization_dir)
|
119 |
+
|
120 |
+
random.seed(args.seed)
|
121 |
+
np.random.seed(args.seed)
|
122 |
+
torch.manual_seed(args.seed)
|
123 |
+
logger.info("set seed: {}".format(args.seed))
|
124 |
+
|
125 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
126 |
+
n_gpu = torch.cuda.device_count()
|
127 |
+
logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
|
128 |
+
|
129 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
130 |
+
|
131 |
+
# datasets
|
132 |
+
logger.info("prepare datasets")
|
133 |
+
train_dataset = WaveClassifierExcelDataset(
|
134 |
+
vocab=vocabulary,
|
135 |
+
excel_file=args.train_dataset,
|
136 |
+
category=None,
|
137 |
+
category_field="category",
|
138 |
+
label_field="labels",
|
139 |
+
expected_sample_rate=8000,
|
140 |
+
max_wave_value=32768.0,
|
141 |
+
)
|
142 |
+
valid_dataset = WaveClassifierExcelDataset(
|
143 |
+
vocab=vocabulary,
|
144 |
+
excel_file=args.valid_dataset,
|
145 |
+
category=None,
|
146 |
+
category_field="category",
|
147 |
+
label_field="labels",
|
148 |
+
expected_sample_rate=8000,
|
149 |
+
max_wave_value=32768.0,
|
150 |
+
)
|
151 |
+
train_data_loader = DataLoader(
|
152 |
+
dataset=train_dataset,
|
153 |
+
batch_size=args.batch_size,
|
154 |
+
shuffle=True,
|
155 |
+
# Linux 系统中可以使用多个子进程加��数据, 而在 Windows 系统中不能.
|
156 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
157 |
+
collate_fn=collate_fn,
|
158 |
+
pin_memory=False,
|
159 |
+
# prefetch_factor=64,
|
160 |
+
)
|
161 |
+
valid_data_loader = DataLoader(
|
162 |
+
dataset=valid_dataset,
|
163 |
+
batch_size=args.batch_size,
|
164 |
+
shuffle=True,
|
165 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
166 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
167 |
+
collate_fn=collate_fn,
|
168 |
+
pin_memory=False,
|
169 |
+
# prefetch_factor=64,
|
170 |
+
)
|
171 |
+
|
172 |
+
# models
|
173 |
+
logger.info(f"prepare models. config_file: {args.config_file}")
|
174 |
+
config = CnnAudioClassifierConfig.from_pretrained(
|
175 |
+
pretrained_model_name_or_path=args.config_file,
|
176 |
+
# num_labels=vocabulary.get_vocab_size(namespace="labels")
|
177 |
+
)
|
178 |
+
if not config.cls_head_param["num_labels"] == vocabulary.get_vocab_size(namespace="labels"):
|
179 |
+
raise AssertionError("expected num labels: {} instead of {}.".format(
|
180 |
+
vocabulary.get_vocab_size(namespace="labels"),
|
181 |
+
config.cls_head_param["num_labels"],
|
182 |
+
))
|
183 |
+
model = WaveClassifierPretrainedModel(
|
184 |
+
config=config,
|
185 |
+
)
|
186 |
+
|
187 |
+
if args.pretrained_model is not None and os.path.exists(args.pretrained_model):
|
188 |
+
logger.info(f"load pretrained model state dict from: {args.pretrained_model}")
|
189 |
+
pretrained_model = Path(args.pretrained_model)
|
190 |
+
with zipfile.ZipFile(pretrained_model.as_posix(), "r") as f_zip:
|
191 |
+
out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
|
192 |
+
# print(out_root.as_posix())
|
193 |
+
if out_root.exists():
|
194 |
+
shutil.rmtree(out_root.as_posix())
|
195 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
196 |
+
f_zip.extractall(path=out_root)
|
197 |
+
|
198 |
+
tgt_path = out_root / pretrained_model.stem
|
199 |
+
model_pt_file = tgt_path / "model.pt"
|
200 |
+
with open(model_pt_file, "rb") as f:
|
201 |
+
state_dict = torch.load(f, map_location="cpu")
|
202 |
+
model.load_state_dict(state_dict=state_dict)
|
203 |
+
|
204 |
+
model.to(device)
|
205 |
+
model.train()
|
206 |
+
|
207 |
+
# optimizer
|
208 |
+
logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
|
209 |
+
param_optimizer = model.parameters()
|
210 |
+
optimizer = torch.optim.Adam(
|
211 |
+
param_optimizer,
|
212 |
+
lr=args.learning_rate,
|
213 |
+
)
|
214 |
+
# lr_scheduler = torch.optim.lr_scheduler.StepLR(
|
215 |
+
# optimizer,
|
216 |
+
# step_size=2000
|
217 |
+
# )
|
218 |
+
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
|
219 |
+
optimizer,
|
220 |
+
milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
|
221 |
+
)
|
222 |
+
focal_loss = FocalLoss(
|
223 |
+
num_classes=vocabulary.get_vocab_size(namespace="labels"),
|
224 |
+
reduction="mean",
|
225 |
+
)
|
226 |
+
categorical_accuracy = CategoricalAccuracy()
|
227 |
+
|
228 |
+
# training loop
|
229 |
+
logger.info("training")
|
230 |
+
|
231 |
+
training_loss = 10000000000
|
232 |
+
training_accuracy = 0.
|
233 |
+
evaluation_loss = 10000000000
|
234 |
+
evaluation_accuracy = 0.
|
235 |
+
|
236 |
+
model_list = list()
|
237 |
+
best_idx_epoch = None
|
238 |
+
best_accuracy = None
|
239 |
+
patience_count = 0
|
240 |
+
|
241 |
+
for idx_epoch in range(args.max_epochs):
|
242 |
+
categorical_accuracy.reset()
|
243 |
+
total_loss = 0.
|
244 |
+
total_examples = 0.
|
245 |
+
progress_bar = tqdm(
|
246 |
+
total=len(train_data_loader),
|
247 |
+
desc="Training; epoch: {}".format(idx_epoch),
|
248 |
+
)
|
249 |
+
for batch in train_data_loader:
|
250 |
+
input_ids, label_ids = batch
|
251 |
+
input_ids = input_ids.to(device)
|
252 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
253 |
+
|
254 |
+
logits = model.forward(input_ids)
|
255 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
256 |
+
categorical_accuracy(logits, label_ids)
|
257 |
+
|
258 |
+
total_loss += loss.item()
|
259 |
+
total_examples += input_ids.size(0)
|
260 |
+
|
261 |
+
optimizer.zero_grad()
|
262 |
+
loss.backward()
|
263 |
+
optimizer.step()
|
264 |
+
lr_scheduler.step()
|
265 |
+
|
266 |
+
training_loss = total_loss / total_examples
|
267 |
+
training_loss = round(training_loss, 4)
|
268 |
+
training_accuracy = categorical_accuracy.get_metric()["accuracy"]
|
269 |
+
training_accuracy = round(training_accuracy, 4)
|
270 |
+
|
271 |
+
progress_bar.update(1)
|
272 |
+
progress_bar.set_postfix({
|
273 |
+
"training_loss": training_loss,
|
274 |
+
"training_accuracy": training_accuracy,
|
275 |
+
})
|
276 |
+
|
277 |
+
categorical_accuracy.reset()
|
278 |
+
total_loss = 0.
|
279 |
+
total_examples = 0.
|
280 |
+
progress_bar = tqdm(
|
281 |
+
total=len(valid_data_loader),
|
282 |
+
desc="Evaluation; epoch: {}".format(idx_epoch),
|
283 |
+
)
|
284 |
+
for batch in valid_data_loader:
|
285 |
+
input_ids, label_ids = batch
|
286 |
+
input_ids = input_ids.to(device)
|
287 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
288 |
+
|
289 |
+
with torch.no_grad():
|
290 |
+
logits = model.forward(input_ids)
|
291 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
292 |
+
categorical_accuracy(logits, label_ids)
|
293 |
+
|
294 |
+
total_loss += loss.item()
|
295 |
+
total_examples += input_ids.size(0)
|
296 |
+
|
297 |
+
evaluation_loss = total_loss / total_examples
|
298 |
+
evaluation_loss = round(evaluation_loss, 4)
|
299 |
+
evaluation_accuracy = categorical_accuracy.get_metric()["accuracy"]
|
300 |
+
evaluation_accuracy = round(evaluation_accuracy, 4)
|
301 |
+
|
302 |
+
progress_bar.update(1)
|
303 |
+
progress_bar.set_postfix({
|
304 |
+
"evaluation_loss": evaluation_loss,
|
305 |
+
"evaluation_accuracy": evaluation_accuracy,
|
306 |
+
})
|
307 |
+
|
308 |
+
# save path
|
309 |
+
epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
|
310 |
+
epoch_dir.mkdir(parents=True, exist_ok=False)
|
311 |
+
|
312 |
+
# save models
|
313 |
+
model.save_pretrained(epoch_dir.as_posix())
|
314 |
+
|
315 |
+
model_list.append(epoch_dir)
|
316 |
+
if len(model_list) >= args.num_serialized_models_to_keep:
|
317 |
+
model_to_delete: Path = model_list.pop(0)
|
318 |
+
shutil.rmtree(model_to_delete.as_posix())
|
319 |
+
|
320 |
+
# save metric
|
321 |
+
if best_accuracy is None:
|
322 |
+
best_idx_epoch = idx_epoch
|
323 |
+
best_accuracy = evaluation_accuracy
|
324 |
+
elif evaluation_accuracy > best_accuracy:
|
325 |
+
best_idx_epoch = idx_epoch
|
326 |
+
best_accuracy = evaluation_accuracy
|
327 |
+
else:
|
328 |
+
pass
|
329 |
+
|
330 |
+
metrics = {
|
331 |
+
"idx_epoch": idx_epoch,
|
332 |
+
"best_idx_epoch": best_idx_epoch,
|
333 |
+
"best_accuracy": best_accuracy,
|
334 |
+
"training_loss": training_loss,
|
335 |
+
"training_accuracy": training_accuracy,
|
336 |
+
"evaluation_loss": evaluation_loss,
|
337 |
+
"evaluation_accuracy": evaluation_accuracy,
|
338 |
+
"learning_rate": optimizer.param_groups[0]['lr'],
|
339 |
+
}
|
340 |
+
metrics_filename = epoch_dir / "metrics_epoch.json"
|
341 |
+
with open(metrics_filename, "w", encoding="utf-8") as f:
|
342 |
+
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
343 |
+
|
344 |
+
# save best
|
345 |
+
best_dir = serialization_dir / "best"
|
346 |
+
if best_idx_epoch == idx_epoch:
|
347 |
+
if best_dir.exists():
|
348 |
+
shutil.rmtree(best_dir)
|
349 |
+
shutil.copytree(epoch_dir, best_dir)
|
350 |
+
|
351 |
+
# early stop
|
352 |
+
early_stop_flag = False
|
353 |
+
if best_idx_epoch == idx_epoch:
|
354 |
+
patience_count = 0
|
355 |
+
else:
|
356 |
+
patience_count += 1
|
357 |
+
if patience_count >= args.patience:
|
358 |
+
early_stop_flag = True
|
359 |
+
|
360 |
+
# early stop
|
361 |
+
if early_stop_flag:
|
362 |
+
break
|
363 |
+
return
|
364 |
+
|
365 |
+
|
366 |
+
if __name__ == "__main__":
|
367 |
+
main()
|
examples/vm_sound_classification/step_4_evaluation_model.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from collections import defaultdict
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
from logging.handlers import TimedRotatingFileHandler
|
8 |
+
import os
|
9 |
+
import platform
|
10 |
+
from pathlib import Path
|
11 |
+
import sys
|
12 |
+
import shutil
|
13 |
+
from typing import List
|
14 |
+
|
15 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
16 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
17 |
+
|
18 |
+
import pandas as pd
|
19 |
+
from scipy.io import wavfile
|
20 |
+
import torch
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
24 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
|
25 |
+
|
26 |
+
|
27 |
+
def get_args():
|
28 |
+
parser = argparse.ArgumentParser()
|
29 |
+
parser.add_argument("--dataset", default="dataset.xlsx", type=str)
|
30 |
+
|
31 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
32 |
+
parser.add_argument("--model_dir", default="best", type=str)
|
33 |
+
|
34 |
+
parser.add_argument("--output_file", default="evaluation.xlsx", type=str)
|
35 |
+
|
36 |
+
args = parser.parse_args()
|
37 |
+
return args
|
38 |
+
|
39 |
+
|
40 |
+
def logging_config():
|
41 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
42 |
+
|
43 |
+
logging.basicConfig(format=fmt,
|
44 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
45 |
+
level=logging.DEBUG)
|
46 |
+
stream_handler = logging.StreamHandler()
|
47 |
+
stream_handler.setLevel(logging.INFO)
|
48 |
+
stream_handler.setFormatter(logging.Formatter(fmt))
|
49 |
+
|
50 |
+
logger = logging.getLogger(__name__)
|
51 |
+
|
52 |
+
return logger
|
53 |
+
|
54 |
+
|
55 |
+
def main():
|
56 |
+
args = get_args()
|
57 |
+
|
58 |
+
logger = logging_config()
|
59 |
+
|
60 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
61 |
+
n_gpu = torch.cuda.device_count()
|
62 |
+
logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
|
63 |
+
|
64 |
+
logger.info("prepare vocabulary, model")
|
65 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
66 |
+
|
67 |
+
model = WaveClassifierPretrainedModel.from_pretrained(
|
68 |
+
pretrained_model_name_or_path=args.model_dir,
|
69 |
+
)
|
70 |
+
model.to(device)
|
71 |
+
model.eval()
|
72 |
+
|
73 |
+
logger.info("read excel")
|
74 |
+
df = pd.read_excel(args.dataset)
|
75 |
+
result = list()
|
76 |
+
|
77 |
+
total_correct = 0
|
78 |
+
total_examples = 0
|
79 |
+
|
80 |
+
progress_bar = tqdm(total=len(df), desc="Evaluation")
|
81 |
+
for i, row in df.iterrows():
|
82 |
+
filename = row["filename"]
|
83 |
+
ground_true = row["labels"]
|
84 |
+
|
85 |
+
sample_rate, waveform = wavfile.read(filename)
|
86 |
+
waveform = waveform / (1 << 15)
|
87 |
+
waveform = torch.tensor(waveform, dtype=torch.float32)
|
88 |
+
waveform = torch.unsqueeze(waveform, dim=0)
|
89 |
+
waveform = waveform.to(device)
|
90 |
+
|
91 |
+
with torch.no_grad():
|
92 |
+
logits = model.forward(waveform)
|
93 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
94 |
+
label_idx = torch.argmax(probs, dim=-1)
|
95 |
+
|
96 |
+
label_idx = label_idx.cpu()
|
97 |
+
probs = probs.cpu()
|
98 |
+
|
99 |
+
label_idx = label_idx.numpy()[0]
|
100 |
+
label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
|
101 |
+
prob = probs[0][label_idx].numpy()
|
102 |
+
|
103 |
+
correct = 1 if label_str == ground_true else 0
|
104 |
+
row_ = dict(row)
|
105 |
+
row_["predict"] = label_str
|
106 |
+
row_["prob"] = prob
|
107 |
+
row_["correct"] = correct
|
108 |
+
result.append(row_)
|
109 |
+
|
110 |
+
total_examples += 1
|
111 |
+
total_correct += correct
|
112 |
+
accuracy = total_correct / total_examples
|
113 |
+
|
114 |
+
progress_bar.update(1)
|
115 |
+
progress_bar.set_postfix({
|
116 |
+
"accuracy": accuracy,
|
117 |
+
})
|
118 |
+
|
119 |
+
result = pd.DataFrame(result)
|
120 |
+
result.to_excel(
|
121 |
+
args.output_file,
|
122 |
+
index=False
|
123 |
+
)
|
124 |
+
return
|
125 |
+
|
126 |
+
|
127 |
+
if __name__ == '__main__':
|
128 |
+
main()
|
examples/vm_sound_classification/step_5_export_models.py
ADDED
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from collections import defaultdict
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
from logging.handlers import TimedRotatingFileHandler
|
8 |
+
import os
|
9 |
+
import platform
|
10 |
+
from pathlib import Path
|
11 |
+
import sys
|
12 |
+
import shutil
|
13 |
+
from typing import List
|
14 |
+
|
15 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
16 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
17 |
+
|
18 |
+
import numpy as np
|
19 |
+
import torch
|
20 |
+
|
21 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
22 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
|
23 |
+
|
24 |
+
|
25 |
+
def get_args():
|
26 |
+
parser = argparse.ArgumentParser()
|
27 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
28 |
+
parser.add_argument("--model_dir", default="best", type=str)
|
29 |
+
|
30 |
+
parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
|
31 |
+
|
32 |
+
args = parser.parse_args()
|
33 |
+
return args
|
34 |
+
|
35 |
+
|
36 |
+
def logging_config():
|
37 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
38 |
+
|
39 |
+
logging.basicConfig(format=fmt,
|
40 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
41 |
+
level=logging.DEBUG)
|
42 |
+
stream_handler = logging.StreamHandler()
|
43 |
+
stream_handler.setLevel(logging.INFO)
|
44 |
+
stream_handler.setFormatter(logging.Formatter(fmt))
|
45 |
+
|
46 |
+
logger = logging.getLogger(__name__)
|
47 |
+
|
48 |
+
return logger
|
49 |
+
|
50 |
+
|
51 |
+
def main():
|
52 |
+
args = get_args()
|
53 |
+
|
54 |
+
serialization_dir = Path(args.serialization_dir)
|
55 |
+
|
56 |
+
logger = logging_config()
|
57 |
+
|
58 |
+
logger.info("export models on CPU")
|
59 |
+
device = torch.device("cpu")
|
60 |
+
|
61 |
+
logger.info("prepare vocabulary, model")
|
62 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
63 |
+
|
64 |
+
model = WaveClassifierPretrainedModel.from_pretrained(
|
65 |
+
pretrained_model_name_or_path=args.model_dir,
|
66 |
+
num_labels=vocabulary.get_vocab_size(namespace="labels")
|
67 |
+
)
|
68 |
+
model.to(device)
|
69 |
+
model.eval()
|
70 |
+
|
71 |
+
waveform = 0 + 25 * np.random.randn(16000,)
|
72 |
+
waveform = np.array(waveform, dtype=np.int16)
|
73 |
+
waveform = waveform / (1 << 15)
|
74 |
+
waveform = torch.tensor(waveform, dtype=torch.float32)
|
75 |
+
waveform = torch.unsqueeze(waveform, dim=0)
|
76 |
+
waveform = waveform.to(device)
|
77 |
+
|
78 |
+
logger.info("export jit models")
|
79 |
+
example_inputs = (waveform,)
|
80 |
+
|
81 |
+
# trace model
|
82 |
+
trace_model = torch.jit.trace(func=model, example_inputs=example_inputs, strict=False)
|
83 |
+
trace_model.save(serialization_dir / "trace_model.zip")
|
84 |
+
|
85 |
+
# quantization trace model (not work on GPU)
|
86 |
+
quantized_model = torch.quantization.quantize_dynamic(
|
87 |
+
model, {torch.nn.Linear}, dtype=torch.qint8
|
88 |
+
)
|
89 |
+
trace_quant_model = torch.jit.trace(func=quantized_model, example_inputs=example_inputs, strict=False)
|
90 |
+
trace_quant_model.save(serialization_dir / "trace_quant_model.zip")
|
91 |
+
|
92 |
+
# script model
|
93 |
+
script_model = torch.jit.script(obj=model)
|
94 |
+
script_model.save(serialization_dir / "script_model.zip")
|
95 |
+
|
96 |
+
# quantization script model (not work on GPU)
|
97 |
+
quantized_model = torch.quantization.quantize_dynamic(
|
98 |
+
model, {torch.nn.Linear}, dtype=torch.qint8
|
99 |
+
)
|
100 |
+
script_quant_model = torch.jit.script(quantized_model)
|
101 |
+
script_quant_model.save(serialization_dir / "script_quant_model.zip")
|
102 |
+
return
|
103 |
+
|
104 |
+
|
105 |
+
if __name__ == '__main__':
|
106 |
+
main()
|
examples/vm_sound_classification/step_6_infer.py
ADDED
@@ -0,0 +1,91 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import shutil
|
7 |
+
import sys
|
8 |
+
import tempfile
|
9 |
+
import zipfile
|
10 |
+
|
11 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
12 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
13 |
+
|
14 |
+
from scipy.io import wavfile
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from project_settings import project_path
|
18 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
19 |
+
|
20 |
+
|
21 |
+
def get_args():
|
22 |
+
parser = argparse.ArgumentParser()
|
23 |
+
parser.add_argument(
|
24 |
+
"--model_file",
|
25 |
+
default=(project_path / "trained_models/vm_sound_classification3.zip").as_posix(),
|
26 |
+
type=str
|
27 |
+
)
|
28 |
+
parser.add_argument(
|
29 |
+
"--wav_file",
|
30 |
+
default=r"C:\Users\tianx\Desktop\a073d03d-d280-46df-9b2d-d904965f4500_zh-CN_h3f25ivhb0c0_1719478037746.wav",
|
31 |
+
type=str
|
32 |
+
)
|
33 |
+
|
34 |
+
parser.add_argument("--device", default="cpu", type=str)
|
35 |
+
|
36 |
+
args = parser.parse_args()
|
37 |
+
return args
|
38 |
+
|
39 |
+
|
40 |
+
def main():
|
41 |
+
args = get_args()
|
42 |
+
|
43 |
+
model_file = Path(args.model_file)
|
44 |
+
|
45 |
+
device = torch.device(args.device)
|
46 |
+
|
47 |
+
with zipfile.ZipFile(model_file, "r") as f_zip:
|
48 |
+
out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
|
49 |
+
print(out_root.as_posix())
|
50 |
+
if out_root.exists():
|
51 |
+
shutil.rmtree(out_root.as_posix())
|
52 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
53 |
+
f_zip.extractall(path=out_root)
|
54 |
+
|
55 |
+
tgt_path = out_root / model_file.stem
|
56 |
+
jit_model_file = tgt_path / "trace_model.zip"
|
57 |
+
vocab_path = tgt_path / "vocabulary"
|
58 |
+
|
59 |
+
with open(jit_model_file.as_posix(), "rb") as f:
|
60 |
+
model = torch.jit.load(f)
|
61 |
+
model.to(device)
|
62 |
+
model.eval()
|
63 |
+
vocabulary = Vocabulary.from_files(vocab_path.as_posix())
|
64 |
+
|
65 |
+
# infer
|
66 |
+
sample_rate, waveform = wavfile.read(args.wav_file)
|
67 |
+
waveform = waveform[:16000]
|
68 |
+
waveform = waveform / (1 << 15)
|
69 |
+
waveform = torch.tensor(waveform, dtype=torch.float32)
|
70 |
+
waveform = torch.unsqueeze(waveform, dim=0)
|
71 |
+
waveform = waveform.to(device)
|
72 |
+
|
73 |
+
with torch.no_grad():
|
74 |
+
logits = model.forward(waveform)
|
75 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
76 |
+
label_idx = torch.argmax(probs, dim=-1)
|
77 |
+
|
78 |
+
label_idx = label_idx.cpu()
|
79 |
+
probs = probs.cpu()
|
80 |
+
|
81 |
+
label_idx = label_idx.numpy()[0]
|
82 |
+
prob = probs.numpy()[0][label_idx]
|
83 |
+
|
84 |
+
label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
|
85 |
+
print(label_str)
|
86 |
+
print(prob)
|
87 |
+
return
|
88 |
+
|
89 |
+
|
90 |
+
if __name__ == '__main__':
|
91 |
+
main()
|
examples/vm_sound_classification/step_7_test_model.py
ADDED
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import shutil
|
7 |
+
import sys
|
8 |
+
import tempfile
|
9 |
+
import zipfile
|
10 |
+
|
11 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
12 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
13 |
+
|
14 |
+
from scipy.io import wavfile
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from project_settings import project_path
|
18 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
19 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
|
20 |
+
|
21 |
+
|
22 |
+
def get_args():
|
23 |
+
parser = argparse.ArgumentParser()
|
24 |
+
parser.add_argument(
|
25 |
+
"--model_file",
|
26 |
+
default=(project_path / "trained_models/vm_sound_classification3.zip").as_posix(),
|
27 |
+
type=str
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"--wav_file",
|
31 |
+
default=r"C:\Users\tianx\Desktop\4b284733-0be3-4a48-abbb-615b32ac44b7_6ndddc2szlh0.wav",
|
32 |
+
type=str
|
33 |
+
)
|
34 |
+
|
35 |
+
parser.add_argument("--device", default="cpu", type=str)
|
36 |
+
|
37 |
+
args = parser.parse_args()
|
38 |
+
return args
|
39 |
+
|
40 |
+
|
41 |
+
def main():
|
42 |
+
args = get_args()
|
43 |
+
|
44 |
+
model_file = Path(args.model_file)
|
45 |
+
|
46 |
+
device = torch.device(args.device)
|
47 |
+
|
48 |
+
with zipfile.ZipFile(model_file, "r") as f_zip:
|
49 |
+
out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
|
50 |
+
print(out_root)
|
51 |
+
if out_root.exists():
|
52 |
+
shutil.rmtree(out_root.as_posix())
|
53 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
54 |
+
f_zip.extractall(path=out_root)
|
55 |
+
|
56 |
+
tgt_path = out_root / model_file.stem
|
57 |
+
vocab_path = tgt_path / "vocabulary"
|
58 |
+
|
59 |
+
vocabulary = Vocabulary.from_files(vocab_path.as_posix())
|
60 |
+
|
61 |
+
model = WaveClassifierPretrainedModel.from_pretrained(
|
62 |
+
pretrained_model_name_or_path=tgt_path.as_posix(),
|
63 |
+
)
|
64 |
+
model.to(device)
|
65 |
+
model.eval()
|
66 |
+
|
67 |
+
# infer
|
68 |
+
sample_rate, waveform = wavfile.read(args.wav_file)
|
69 |
+
waveform = waveform[:16000]
|
70 |
+
waveform = waveform / (1 << 15)
|
71 |
+
waveform = torch.tensor(waveform, dtype=torch.float32)
|
72 |
+
waveform = torch.unsqueeze(waveform, dim=0)
|
73 |
+
waveform = waveform.to(device)
|
74 |
+
print(waveform.shape)
|
75 |
+
with torch.no_grad():
|
76 |
+
logits = model.forward(waveform)
|
77 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
78 |
+
label_idx = torch.argmax(probs, dim=-1)
|
79 |
+
|
80 |
+
label_idx = label_idx.cpu()
|
81 |
+
probs = probs.cpu()
|
82 |
+
|
83 |
+
label_idx = label_idx.numpy()[0]
|
84 |
+
prob = probs.numpy()[0][label_idx]
|
85 |
+
|
86 |
+
label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
|
87 |
+
print(label_str)
|
88 |
+
print(prob)
|
89 |
+
return
|
90 |
+
|
91 |
+
|
92 |
+
if __name__ == '__main__':
|
93 |
+
main()
|
examples/vm_sound_classification/stop.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
kill -9 `ps -aef | grep 'vm_sound_classification/bin/python3' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
|
examples/vm_sound_classification/yaml/conv2d-classifier-2-ch16.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 16
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 16
|
23 |
+
out_channels: 16
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 16
|
30 |
+
out_channels: 16
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 432
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 2
|
examples/vm_sound_classification/yaml/conv2d-classifier-2-ch32.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 32
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 32
|
23 |
+
out_channels: 32
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 32
|
30 |
+
out_channels: 32
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 864
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 2
|
examples/vm_sound_classification/yaml/conv2d-classifier-2-ch4.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 4
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 4
|
23 |
+
out_channels: 4
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 4
|
30 |
+
out_channels: 4
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 108
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 2
|
examples/vm_sound_classification/yaml/conv2d-classifier-2-ch8.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 8
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 8
|
23 |
+
out_channels: 8
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 8
|
30 |
+
out_channels: 8
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 216
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 2
|
examples/vm_sound_classification/yaml/conv2d-classifier-3-ch16.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 16
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 16
|
23 |
+
out_channels: 16
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 16
|
30 |
+
out_channels: 16
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 432
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 3
|
examples/vm_sound_classification/yaml/conv2d-classifier-3-ch32.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 32
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 32
|
23 |
+
out_channels: 32
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 32
|
30 |
+
out_channels: 32
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 864
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 3
|
examples/vm_sound_classification/yaml/conv2d-classifier-3-ch4.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 4
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 4
|
23 |
+
out_channels: 4
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 4
|
30 |
+
out_channels: 4
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 108
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 3
|
examples/vm_sound_classification/yaml/conv2d-classifier-3-ch8.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 8
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 8
|
23 |
+
out_channels: 8
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 8
|
30 |
+
out_channels: 8
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 216
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 3
|
examples/vm_sound_classification/yaml/conv2d-classifier-4-ch16.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 16
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 16
|
23 |
+
out_channels: 16
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 16
|
30 |
+
out_channels: 16
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 432
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 4
|
examples/vm_sound_classification/yaml/conv2d-classifier-4-ch32.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 32
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 32
|
23 |
+
out_channels: 32
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 32
|
30 |
+
out_channels: 32
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 864
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 4
|
examples/vm_sound_classification/yaml/conv2d-classifier-4-ch4.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 4
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 4
|
23 |
+
out_channels: 4
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 4
|
30 |
+
out_channels: 4
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 108
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 4
|
examples/vm_sound_classification/yaml/conv2d-classifier-4-ch8.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 8
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 8
|
23 |
+
out_channels: 8
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 8
|
30 |
+
out_channels: 8
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 216
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 4
|
examples/vm_sound_classification/yaml/conv2d-classifier-8-ch16.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 16
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 16
|
23 |
+
out_channels: 16
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 16
|
30 |
+
out_channels: 16
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 432
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 8
|
examples/vm_sound_classification/yaml/conv2d-classifier-8-ch32.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 32
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 32
|
23 |
+
out_channels: 32
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 32
|
30 |
+
out_channels: 32
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 864
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 8
|
examples/vm_sound_classification/yaml/conv2d-classifier-8-ch4.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 4
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 4
|
23 |
+
out_channels: 4
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 4
|
30 |
+
out_channels: 4
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 108
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 8
|
examples/vm_sound_classification/yaml/conv2d-classifier-8-ch8.yaml
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
model_name: "cnn_audio_classifier"
|
2 |
+
|
3 |
+
mel_spectrogram_param:
|
4 |
+
sample_rate: 8000
|
5 |
+
n_fft: 512
|
6 |
+
win_length: 200
|
7 |
+
hop_length: 80
|
8 |
+
f_min: 10
|
9 |
+
f_max: 3800
|
10 |
+
window_fn: hamming
|
11 |
+
n_mels: 80
|
12 |
+
|
13 |
+
conv2d_block_param_list:
|
14 |
+
- batch_norm: true
|
15 |
+
in_channels: 1
|
16 |
+
out_channels: 8
|
17 |
+
kernel_size: 3
|
18 |
+
stride: 1
|
19 |
+
dilation: 3
|
20 |
+
activation: relu
|
21 |
+
dropout: 0.1
|
22 |
+
- in_channels: 8
|
23 |
+
out_channels: 8
|
24 |
+
kernel_size: 5
|
25 |
+
stride: 2
|
26 |
+
dilation: 3
|
27 |
+
activation: relu
|
28 |
+
dropout: 0.1
|
29 |
+
- in_channels: 8
|
30 |
+
out_channels: 8
|
31 |
+
kernel_size: 3
|
32 |
+
stride: 1
|
33 |
+
dilation: 2
|
34 |
+
activation: relu
|
35 |
+
dropout: 0.1
|
36 |
+
|
37 |
+
cls_head_param:
|
38 |
+
input_dim: 216
|
39 |
+
num_layers: 2
|
40 |
+
hidden_dims:
|
41 |
+
- 128
|
42 |
+
- 32
|
43 |
+
activations: relu
|
44 |
+
dropout: 0.1
|
45 |
+
num_labels: 8
|
examples/vm_sound_classification8/requirements.txt
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==1.10.1
|
2 |
+
torchaudio==0.10.1
|
3 |
+
fsspec==2022.1.0
|
4 |
+
librosa==0.9.2
|
5 |
+
pandas==1.1.5
|
6 |
+
openpyxl==3.0.9
|
7 |
+
xlrd==1.2.0
|
8 |
+
tqdm==4.64.1
|
9 |
+
overrides==1.9.0
|
examples/vm_sound_classification8/run.sh
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
: <<'END'
|
4 |
+
|
5 |
+
sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
|
6 |
+
--filename_patterns "E:/programmer/asr_datasets/voicemail/wav_finished/en-US/wav_finished/*/*.wav \
|
7 |
+
E:/programmer/asr_datasets/voicemail/wav_finished/id-ID/wav_finished/*/*.wav" \
|
8 |
+
|
9 |
+
|
10 |
+
sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
|
11 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
12 |
+
|
13 |
+
sh run.sh --stage 4 --stop_stage 4 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
|
14 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
15 |
+
|
16 |
+
sh run.sh --stage 4 --stop_stage 4 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification8 \
|
17 |
+
--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
18 |
+
|
19 |
+
|
20 |
+
"
|
21 |
+
|
22 |
+
END
|
23 |
+
|
24 |
+
|
25 |
+
# sh run.sh --stage -1 --stop_stage 9
|
26 |
+
# sh run.sh --stage -1 --stop_stage 5 --system_version centos --file_folder_name task_cnn_voicemail_id_id --final_model_name cnn_voicemail_id_id
|
27 |
+
# sh run.sh --stage 3 --stop_stage 4
|
28 |
+
# sh run.sh --stage 4 --stop_stage 4
|
29 |
+
# sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name task_cnn_voicemail_id_id
|
30 |
+
|
31 |
+
# params
|
32 |
+
system_version="windows";
|
33 |
+
verbose=true;
|
34 |
+
stage=0 # start from 0 if you need to start from data preparation
|
35 |
+
stop_stage=9
|
36 |
+
|
37 |
+
work_dir="$(pwd)"
|
38 |
+
file_folder_name=file_folder_name
|
39 |
+
final_model_name=final_model_name
|
40 |
+
filename_patterns="/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
|
41 |
+
nohup_name=nohup.out
|
42 |
+
|
43 |
+
country=en-US
|
44 |
+
|
45 |
+
# model params
|
46 |
+
batch_size=64
|
47 |
+
max_epochs=200
|
48 |
+
save_top_k=10
|
49 |
+
patience=5
|
50 |
+
|
51 |
+
|
52 |
+
# parse options
|
53 |
+
while true; do
|
54 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
55 |
+
case "$1" in
|
56 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
57 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
58 |
+
old_value="(eval echo \\$$name)";
|
59 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
60 |
+
was_bool=true;
|
61 |
+
else
|
62 |
+
was_bool=false;
|
63 |
+
fi
|
64 |
+
|
65 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
66 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
67 |
+
eval "${name}=\"$2\"";
|
68 |
+
|
69 |
+
# Check that Boolean-valued arguments are really Boolean.
|
70 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
71 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
72 |
+
exit 1;
|
73 |
+
fi
|
74 |
+
shift 2;
|
75 |
+
;;
|
76 |
+
|
77 |
+
*) break;
|
78 |
+
esac
|
79 |
+
done
|
80 |
+
|
81 |
+
file_dir="${work_dir}/${file_folder_name}"
|
82 |
+
final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
|
83 |
+
|
84 |
+
train_dataset="${file_dir}/train.xlsx"
|
85 |
+
valid_dataset="${file_dir}/valid.xlsx"
|
86 |
+
vocabulary_dir="${file_dir}/vocabulary"
|
87 |
+
|
88 |
+
|
89 |
+
$verbose && echo "system_version: ${system_version}"
|
90 |
+
$verbose && echo "file_folder_name: ${file_folder_name}"
|
91 |
+
|
92 |
+
if [ $system_version == "windows" ]; then
|
93 |
+
alias python3='D:/Users/tianx/PycharmProjects/virtualenv/vm_sound_classification/Scripts/python.exe'
|
94 |
+
elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
|
95 |
+
#source /data/local/bin/vm_sound_classification/bin/activate
|
96 |
+
alias python3='/data/local/bin/vm_sound_classification/bin/python3'
|
97 |
+
fi
|
98 |
+
|
99 |
+
|
100 |
+
if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
|
101 |
+
$verbose && echo "stage 0: prepare data"
|
102 |
+
cd "${work_dir}" || exit 1
|
103 |
+
python3 step_1_prepare_data.py \
|
104 |
+
--file_dir "${file_dir}" \
|
105 |
+
--filename_patterns "${filename_patterns}" \
|
106 |
+
--train_dataset "${train_dataset}" \
|
107 |
+
--valid_dataset "${valid_dataset}" \
|
108 |
+
|
109 |
+
fi
|
110 |
+
|
111 |
+
|
112 |
+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
113 |
+
$verbose && echo "stage 1: make vocabulary"
|
114 |
+
cd "${work_dir}" || exit 1
|
115 |
+
python3 step_2_make_vocabulary.py \
|
116 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
117 |
+
--train_dataset "${train_dataset}" \
|
118 |
+
--valid_dataset "${valid_dataset}" \
|
119 |
+
|
120 |
+
fi
|
121 |
+
|
122 |
+
|
123 |
+
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
124 |
+
$verbose && echo "stage 2: train global model"
|
125 |
+
cd "${work_dir}" || exit 1
|
126 |
+
python3 step_3_train_global_model.py \
|
127 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
128 |
+
--train_dataset "${train_dataset}" \
|
129 |
+
--valid_dataset "${valid_dataset}" \
|
130 |
+
--serialization_dir "${file_dir}/global_model" \
|
131 |
+
|
132 |
+
fi
|
133 |
+
|
134 |
+
|
135 |
+
if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
|
136 |
+
$verbose && echo "stage 3: train country model"
|
137 |
+
cd "${work_dir}" || exit 1
|
138 |
+
python3 step_4_train_country_model.py \
|
139 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
140 |
+
--train_dataset "${train_dataset}" \
|
141 |
+
--valid_dataset "${valid_dataset}" \
|
142 |
+
--country "${country}" \
|
143 |
+
--serialization_dir "${file_dir}/country_model" \
|
144 |
+
|
145 |
+
fi
|
146 |
+
|
147 |
+
|
148 |
+
if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
|
149 |
+
$verbose && echo "stage 4: train union model"
|
150 |
+
cd "${work_dir}" || exit 1
|
151 |
+
python3 step_5_train_union.py \
|
152 |
+
--vocabulary_dir "${vocabulary_dir}" \
|
153 |
+
--train_dataset "${train_dataset}" \
|
154 |
+
--valid_dataset "${valid_dataset}" \
|
155 |
+
--serialization_dir "${file_dir}/union" \
|
156 |
+
|
157 |
+
fi
|
examples/vm_sound_classification8/step_1_prepare_data.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from glob import glob
|
5 |
+
import json
|
6 |
+
import os
|
7 |
+
from pathlib import Path
|
8 |
+
import random
|
9 |
+
import sys
|
10 |
+
|
11 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
12 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
13 |
+
|
14 |
+
import pandas as pd
|
15 |
+
from scipy.io import wavfile
|
16 |
+
from tqdm import tqdm
|
17 |
+
|
18 |
+
|
19 |
+
def get_args():
|
20 |
+
parser = argparse.ArgumentParser()
|
21 |
+
parser.add_argument("--file_dir", default="./", type=str)
|
22 |
+
parser.add_argument("--task", default="default", type=str)
|
23 |
+
parser.add_argument("--filename_patterns", type=str)
|
24 |
+
|
25 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
26 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
27 |
+
|
28 |
+
args = parser.parse_args()
|
29 |
+
return args
|
30 |
+
|
31 |
+
|
32 |
+
def get_dataset(args):
|
33 |
+
filename_patterns = args.filename_patterns
|
34 |
+
filename_patterns = filename_patterns.split(" ")
|
35 |
+
print(filename_patterns)
|
36 |
+
|
37 |
+
file_dir = Path(args.file_dir)
|
38 |
+
file_dir.mkdir(exist_ok=True)
|
39 |
+
|
40 |
+
global_label_map = {
|
41 |
+
"bell": "bell",
|
42 |
+
"white_noise": "white_noise",
|
43 |
+
"low_white_noise": "white_noise",
|
44 |
+
"high_white_noise": "noise",
|
45 |
+
"music": "music",
|
46 |
+
"mute": "mute",
|
47 |
+
"noise": "noise",
|
48 |
+
"noise_mute": "noise_mute",
|
49 |
+
"voice": "voice",
|
50 |
+
"voicemail": "voicemail",
|
51 |
+
}
|
52 |
+
|
53 |
+
country_label_map = {
|
54 |
+
"bell": "voicemail",
|
55 |
+
"white_noise": "non_voicemail",
|
56 |
+
"low_white_noise": "non_voicemail",
|
57 |
+
"hight_white_noise": "non_voicemail",
|
58 |
+
"music": "non_voicemail",
|
59 |
+
"mute": "non_voicemail",
|
60 |
+
"noise": "non_voicemail",
|
61 |
+
"noise_mute": "non_voicemail",
|
62 |
+
"voice": "non_voicemail",
|
63 |
+
"voicemail": "voicemail",
|
64 |
+
"non_voicemail": "non_voicemail",
|
65 |
+
}
|
66 |
+
|
67 |
+
result = list()
|
68 |
+
for filename_pattern in filename_patterns:
|
69 |
+
filename_list = glob(filename_pattern)
|
70 |
+
for filename in tqdm(filename_list):
|
71 |
+
filename = Path(filename)
|
72 |
+
sample_rate, signal = wavfile.read(filename.as_posix())
|
73 |
+
if len(signal) < sample_rate * 2:
|
74 |
+
continue
|
75 |
+
|
76 |
+
folder = filename.parts[-2]
|
77 |
+
country = filename.parts[-4]
|
78 |
+
|
79 |
+
if folder not in global_label_map.keys():
|
80 |
+
continue
|
81 |
+
if folder not in country_label_map.keys():
|
82 |
+
continue
|
83 |
+
|
84 |
+
global_label = global_label_map[folder]
|
85 |
+
country_label = country_label_map[folder]
|
86 |
+
|
87 |
+
random1 = random.random()
|
88 |
+
random2 = random.random()
|
89 |
+
|
90 |
+
result.append({
|
91 |
+
"filename": filename,
|
92 |
+
"folder": folder,
|
93 |
+
"category": country,
|
94 |
+
"global_labels": global_label,
|
95 |
+
"country_labels": country_label,
|
96 |
+
"random1": random1,
|
97 |
+
"random2": random2,
|
98 |
+
"flag": "TRAIN" if random2 < 0.8 else "TEST",
|
99 |
+
})
|
100 |
+
|
101 |
+
df = pd.DataFrame(result)
|
102 |
+
pivot_table = pd.pivot_table(df, index=["global_labels"], values=["filename"], aggfunc="count")
|
103 |
+
print(pivot_table)
|
104 |
+
|
105 |
+
df = df.sort_values(by=["random1"], ascending=False)
|
106 |
+
df.to_excel(
|
107 |
+
file_dir / "dataset.xlsx",
|
108 |
+
index=False,
|
109 |
+
# encoding="utf_8_sig"
|
110 |
+
)
|
111 |
+
|
112 |
+
return
|
113 |
+
|
114 |
+
|
115 |
+
def split_dataset(args):
|
116 |
+
"""分割训练集, 测试集"""
|
117 |
+
file_dir = Path(args.file_dir)
|
118 |
+
file_dir.mkdir(exist_ok=True)
|
119 |
+
|
120 |
+
df = pd.read_excel(file_dir / "dataset.xlsx")
|
121 |
+
|
122 |
+
train = list()
|
123 |
+
test = list()
|
124 |
+
|
125 |
+
for i, row in df.iterrows():
|
126 |
+
flag = row["flag"]
|
127 |
+
if flag == "TRAIN":
|
128 |
+
train.append(row)
|
129 |
+
else:
|
130 |
+
test.append(row)
|
131 |
+
|
132 |
+
train = pd.DataFrame(train)
|
133 |
+
train.to_excel(
|
134 |
+
args.train_dataset,
|
135 |
+
index=False,
|
136 |
+
# encoding="utf_8_sig"
|
137 |
+
)
|
138 |
+
test = pd.DataFrame(test)
|
139 |
+
test.to_excel(
|
140 |
+
args.valid_dataset,
|
141 |
+
index=False,
|
142 |
+
# encoding="utf_8_sig"
|
143 |
+
)
|
144 |
+
|
145 |
+
return
|
146 |
+
|
147 |
+
|
148 |
+
def main():
|
149 |
+
args = get_args()
|
150 |
+
get_dataset(args)
|
151 |
+
split_dataset(args)
|
152 |
+
return
|
153 |
+
|
154 |
+
|
155 |
+
if __name__ == "__main__":
|
156 |
+
main()
|
examples/vm_sound_classification8/step_2_make_vocabulary.py
ADDED
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
import os
|
5 |
+
from pathlib import Path
|
6 |
+
import sys
|
7 |
+
|
8 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
9 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
10 |
+
|
11 |
+
import pandas as pd
|
12 |
+
|
13 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
14 |
+
|
15 |
+
|
16 |
+
def get_args():
|
17 |
+
parser = argparse.ArgumentParser()
|
18 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
19 |
+
|
20 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
21 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
22 |
+
|
23 |
+
args = parser.parse_args()
|
24 |
+
return args
|
25 |
+
|
26 |
+
|
27 |
+
def main():
|
28 |
+
args = get_args()
|
29 |
+
|
30 |
+
train_dataset = pd.read_excel(args.train_dataset)
|
31 |
+
valid_dataset = pd.read_excel(args.valid_dataset)
|
32 |
+
|
33 |
+
# non_padded_namespaces
|
34 |
+
category_set = set()
|
35 |
+
for i, row in train_dataset.iterrows():
|
36 |
+
category = row["category"]
|
37 |
+
category_set.add(category)
|
38 |
+
|
39 |
+
for i, row in valid_dataset.iterrows():
|
40 |
+
category = row["category"]
|
41 |
+
category_set.add(category)
|
42 |
+
|
43 |
+
vocabulary = Vocabulary(non_padded_namespaces=["global_labels", *list(category_set)])
|
44 |
+
|
45 |
+
# train
|
46 |
+
for i, row in train_dataset.iterrows():
|
47 |
+
global_labels = row["global_labels"]
|
48 |
+
country_labels = row["country_labels"]
|
49 |
+
category = row["category"]
|
50 |
+
|
51 |
+
vocabulary.add_token_to_namespace(global_labels, "global_labels")
|
52 |
+
vocabulary.add_token_to_namespace(country_labels, category)
|
53 |
+
|
54 |
+
# valid
|
55 |
+
for i, row in valid_dataset.iterrows():
|
56 |
+
global_labels = row["global_labels"]
|
57 |
+
country_labels = row["country_labels"]
|
58 |
+
category = row["category"]
|
59 |
+
|
60 |
+
vocabulary.add_token_to_namespace(global_labels, "global_labels")
|
61 |
+
vocabulary.add_token_to_namespace(country_labels, category)
|
62 |
+
|
63 |
+
vocabulary.save_to_files(args.vocabulary_dir)
|
64 |
+
|
65 |
+
return
|
66 |
+
|
67 |
+
|
68 |
+
if __name__ == "__main__":
|
69 |
+
main()
|
examples/vm_sound_classification8/step_3_train_global_model.py
ADDED
@@ -0,0 +1,328 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
之前的代码达到准确率0.8423
|
5 |
+
此代码达到准确率0.8379
|
6 |
+
此代码可行.
|
7 |
+
"""
|
8 |
+
import argparse
|
9 |
+
import copy
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
from logging.handlers import TimedRotatingFileHandler
|
13 |
+
import os
|
14 |
+
from pathlib import Path
|
15 |
+
import platform
|
16 |
+
import sys
|
17 |
+
from typing import List
|
18 |
+
|
19 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
20 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
21 |
+
|
22 |
+
import torch
|
23 |
+
from torch.utils.data.dataloader import DataLoader
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
|
27 |
+
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
|
28 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
29 |
+
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
|
30 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
|
31 |
+
|
32 |
+
|
33 |
+
def get_args():
|
34 |
+
parser = argparse.ArgumentParser()
|
35 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
36 |
+
|
37 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
38 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
39 |
+
|
40 |
+
parser.add_argument("--max_epochs", default=100, type=int)
|
41 |
+
parser.add_argument("--batch_size", default=64, type=int)
|
42 |
+
parser.add_argument("--learning_rate", default=1e-3, type=float)
|
43 |
+
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
44 |
+
parser.add_argument("--patience", default=5, type=int)
|
45 |
+
parser.add_argument("--serialization_dir", default="global_classifier", type=str)
|
46 |
+
parser.add_argument("--seed", default=0, type=int)
|
47 |
+
|
48 |
+
args = parser.parse_args()
|
49 |
+
return args
|
50 |
+
|
51 |
+
|
52 |
+
def logging_config(file_dir: str):
|
53 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
54 |
+
|
55 |
+
logging.basicConfig(format=fmt,
|
56 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
57 |
+
level=logging.DEBUG)
|
58 |
+
file_handler = TimedRotatingFileHandler(
|
59 |
+
filename=os.path.join(file_dir, "main.log"),
|
60 |
+
encoding="utf-8",
|
61 |
+
when="D",
|
62 |
+
interval=1,
|
63 |
+
backupCount=7
|
64 |
+
)
|
65 |
+
file_handler.setLevel(logging.INFO)
|
66 |
+
file_handler.setFormatter(logging.Formatter(fmt))
|
67 |
+
logger = logging.getLogger(__name__)
|
68 |
+
logger.addHandler(file_handler)
|
69 |
+
|
70 |
+
return logger
|
71 |
+
|
72 |
+
|
73 |
+
class CollateFunction(object):
|
74 |
+
def __init__(self):
|
75 |
+
pass
|
76 |
+
|
77 |
+
def __call__(self, batch: List[dict]):
|
78 |
+
array_list = list()
|
79 |
+
label_list = list()
|
80 |
+
for sample in batch:
|
81 |
+
array = sample["waveform"]
|
82 |
+
label = sample["label"]
|
83 |
+
|
84 |
+
array_list.append(array)
|
85 |
+
label_list.append(label)
|
86 |
+
|
87 |
+
array_list = torch.stack(array_list)
|
88 |
+
label_list = torch.stack(label_list)
|
89 |
+
return array_list, label_list
|
90 |
+
|
91 |
+
|
92 |
+
collate_fn = CollateFunction()
|
93 |
+
|
94 |
+
|
95 |
+
def main():
|
96 |
+
args = get_args()
|
97 |
+
|
98 |
+
serialization_dir = Path(args.serialization_dir)
|
99 |
+
serialization_dir.mkdir(parents=True, exist_ok=True)
|
100 |
+
|
101 |
+
logger = logging_config(args.serialization_dir)
|
102 |
+
|
103 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
104 |
+
n_gpu = torch.cuda.device_count()
|
105 |
+
logger.info("GPU available: {}; device: {}".format(n_gpu, device))
|
106 |
+
|
107 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
108 |
+
|
109 |
+
# datasets
|
110 |
+
train_dataset = WaveClassifierExcelDataset(
|
111 |
+
vocab=vocabulary,
|
112 |
+
excel_file=args.train_dataset,
|
113 |
+
category=None,
|
114 |
+
category_field="category",
|
115 |
+
label_field="global_labels",
|
116 |
+
expected_sample_rate=8000,
|
117 |
+
max_wave_value=32768.0,
|
118 |
+
)
|
119 |
+
valid_dataset = WaveClassifierExcelDataset(
|
120 |
+
vocab=vocabulary,
|
121 |
+
excel_file=args.valid_dataset,
|
122 |
+
category=None,
|
123 |
+
category_field="category",
|
124 |
+
label_field="global_labels",
|
125 |
+
expected_sample_rate=8000,
|
126 |
+
max_wave_value=32768.0,
|
127 |
+
)
|
128 |
+
|
129 |
+
train_data_loader = DataLoader(
|
130 |
+
dataset=train_dataset,
|
131 |
+
batch_size=args.batch_size,
|
132 |
+
shuffle=True,
|
133 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
134 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
|
135 |
+
collate_fn=collate_fn,
|
136 |
+
pin_memory=False,
|
137 |
+
# prefetch_factor=64,
|
138 |
+
)
|
139 |
+
valid_data_loader = DataLoader(
|
140 |
+
dataset=valid_dataset,
|
141 |
+
batch_size=args.batch_size,
|
142 |
+
shuffle=True,
|
143 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
144 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
|
145 |
+
collate_fn=collate_fn,
|
146 |
+
pin_memory=False,
|
147 |
+
# prefetch_factor=64,
|
148 |
+
)
|
149 |
+
|
150 |
+
# models - classifier
|
151 |
+
wave_encoder = WaveEncoder(
|
152 |
+
conv1d_block_param_list=[
|
153 |
+
{
|
154 |
+
'batch_norm': True,
|
155 |
+
'in_channels': 80,
|
156 |
+
'out_channels': 16,
|
157 |
+
'kernel_size': 3,
|
158 |
+
'stride': 3,
|
159 |
+
# 'padding': 'same',
|
160 |
+
'activation': 'relu',
|
161 |
+
'dropout': 0.1,
|
162 |
+
},
|
163 |
+
{
|
164 |
+
# 'batch_norm': True,
|
165 |
+
'in_channels': 16,
|
166 |
+
'out_channels': 16,
|
167 |
+
'kernel_size': 3,
|
168 |
+
'stride': 3,
|
169 |
+
# 'padding': 'same',
|
170 |
+
'activation': 'relu',
|
171 |
+
'dropout': 0.1,
|
172 |
+
},
|
173 |
+
{
|
174 |
+
# 'batch_norm': True,
|
175 |
+
'in_channels': 16,
|
176 |
+
'out_channels': 16,
|
177 |
+
'kernel_size': 3,
|
178 |
+
'stride': 3,
|
179 |
+
# 'padding': 'same',
|
180 |
+
'activation': 'relu',
|
181 |
+
'dropout': 0.1,
|
182 |
+
},
|
183 |
+
],
|
184 |
+
mel_spectrogram_param={
|
185 |
+
"sample_rate": 8000,
|
186 |
+
"n_fft": 512,
|
187 |
+
"win_length": 200,
|
188 |
+
"hop_length": 80,
|
189 |
+
"f_min": 10,
|
190 |
+
"f_max": 3800,
|
191 |
+
"window_fn": "hamming",
|
192 |
+
"n_mels": 80,
|
193 |
+
}
|
194 |
+
)
|
195 |
+
cls_head = ClsHead(
|
196 |
+
input_dim=16,
|
197 |
+
num_layers=2,
|
198 |
+
hidden_dims=[32, 16],
|
199 |
+
activations="relu",
|
200 |
+
dropout=0.1,
|
201 |
+
num_labels=vocabulary.get_vocab_size(namespace="global_labels")
|
202 |
+
)
|
203 |
+
model = WaveClassifier(
|
204 |
+
wave_encoder=wave_encoder,
|
205 |
+
cls_head=cls_head,
|
206 |
+
)
|
207 |
+
model.to(device)
|
208 |
+
|
209 |
+
# optimizer
|
210 |
+
optimizer = torch.optim.Adam(
|
211 |
+
model.parameters(),
|
212 |
+
lr=args.learning_rate
|
213 |
+
)
|
214 |
+
lr_scheduler = torch.optim.lr_scheduler.StepLR(
|
215 |
+
optimizer,
|
216 |
+
step_size=30000
|
217 |
+
)
|
218 |
+
focal_loss = FocalLoss(
|
219 |
+
num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
|
220 |
+
reduction="mean",
|
221 |
+
)
|
222 |
+
categorical_accuracy = CategoricalAccuracy()
|
223 |
+
|
224 |
+
# training
|
225 |
+
best_idx_epoch: int = None
|
226 |
+
best_accuracy: float = None
|
227 |
+
patience_count = 0
|
228 |
+
global_step = 0
|
229 |
+
model_filename_list = list()
|
230 |
+
for idx_epoch in range(args.max_epochs):
|
231 |
+
|
232 |
+
# training
|
233 |
+
model.train()
|
234 |
+
total_loss = 0
|
235 |
+
total_examples = 0
|
236 |
+
for step, batch in enumerate(tqdm(train_data_loader, desc="Epoch={} (training)".format(idx_epoch))):
|
237 |
+
input_ids, label_ids = batch
|
238 |
+
input_ids = input_ids.to(device)
|
239 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
240 |
+
|
241 |
+
logits = model.forward(input_ids)
|
242 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
243 |
+
categorical_accuracy(logits, label_ids)
|
244 |
+
|
245 |
+
total_loss += loss.item()
|
246 |
+
total_examples += input_ids.size(0)
|
247 |
+
|
248 |
+
optimizer.zero_grad()
|
249 |
+
loss.backward()
|
250 |
+
optimizer.step()
|
251 |
+
lr_scheduler.step()
|
252 |
+
|
253 |
+
global_step += 1
|
254 |
+
training_loss = total_loss / total_examples
|
255 |
+
training_loss = round(training_loss, 4)
|
256 |
+
training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
257 |
+
training_accuracy = round(training_accuracy, 4)
|
258 |
+
logger.info("Epoch: {}; training_loss: {}; training_accuracy: {}".format(
|
259 |
+
idx_epoch, training_loss, training_accuracy
|
260 |
+
))
|
261 |
+
|
262 |
+
# evaluation
|
263 |
+
model.eval()
|
264 |
+
total_loss = 0
|
265 |
+
total_examples = 0
|
266 |
+
for step, batch in enumerate(tqdm(valid_data_loader, desc="Epoch={} (evaluation)".format(idx_epoch))):
|
267 |
+
input_ids, label_ids = batch
|
268 |
+
input_ids = input_ids.to(device)
|
269 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
270 |
+
|
271 |
+
with torch.no_grad():
|
272 |
+
logits = model.forward(input_ids)
|
273 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
274 |
+
categorical_accuracy(logits, label_ids)
|
275 |
+
|
276 |
+
total_loss += loss.item()
|
277 |
+
total_examples += input_ids.size(0)
|
278 |
+
|
279 |
+
evaluation_loss = total_loss / total_examples
|
280 |
+
evaluation_loss = round(evaluation_loss, 4)
|
281 |
+
evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
282 |
+
evaluation_accuracy = round(evaluation_accuracy, 4)
|
283 |
+
logger.info("Epoch: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
|
284 |
+
idx_epoch, evaluation_loss, evaluation_accuracy
|
285 |
+
))
|
286 |
+
|
287 |
+
# save metric
|
288 |
+
metrics = {
|
289 |
+
"training_loss": training_loss,
|
290 |
+
"training_accuracy": training_accuracy,
|
291 |
+
"evaluation_loss": evaluation_loss,
|
292 |
+
"evaluation_accuracy": evaluation_accuracy,
|
293 |
+
"best_idx_epoch": best_idx_epoch,
|
294 |
+
"best_accuracy": best_accuracy,
|
295 |
+
}
|
296 |
+
metrics_filename = os.path.join(args.serialization_dir, "metrics_epoch_{}.json".format(idx_epoch))
|
297 |
+
with open(metrics_filename, "w", encoding="utf-8") as f:
|
298 |
+
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
299 |
+
|
300 |
+
# save model
|
301 |
+
model_filename = os.path.join(args.serialization_dir, "model_epoch_{}.bin".format(idx_epoch))
|
302 |
+
model_filename_list.append(model_filename)
|
303 |
+
if len(model_filename_list) >= args.num_serialized_models_to_keep:
|
304 |
+
model_filename_to_delete = model_filename_list.pop(0)
|
305 |
+
os.remove(model_filename_to_delete)
|
306 |
+
torch.save(model.state_dict(), model_filename)
|
307 |
+
|
308 |
+
# early stop
|
309 |
+
best_model_filename = os.path.join(args.serialization_dir, "best.bin")
|
310 |
+
if best_accuracy is None:
|
311 |
+
best_idx_epoch = idx_epoch
|
312 |
+
best_accuracy = evaluation_accuracy
|
313 |
+
torch.save(model.state_dict(), best_model_filename)
|
314 |
+
elif evaluation_accuracy > best_accuracy:
|
315 |
+
best_idx_epoch = idx_epoch
|
316 |
+
best_accuracy = evaluation_accuracy
|
317 |
+
torch.save(model.state_dict(), best_model_filename)
|
318 |
+
patience_count = 0
|
319 |
+
elif patience_count >= args.patience:
|
320 |
+
break
|
321 |
+
else:
|
322 |
+
patience_count += 1
|
323 |
+
|
324 |
+
return
|
325 |
+
|
326 |
+
|
327 |
+
if __name__ == "__main__":
|
328 |
+
main()
|
examples/vm_sound_classification8/step_4_train_country_model.py
ADDED
@@ -0,0 +1,349 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
"""
|
4 |
+
只训练 cls_head 部分的参数, 模型的准确率会更低.
|
5 |
+
"""
|
6 |
+
import argparse
|
7 |
+
from collections import defaultdict
|
8 |
+
import json
|
9 |
+
import logging
|
10 |
+
from logging.handlers import TimedRotatingFileHandler
|
11 |
+
import os
|
12 |
+
import platform
|
13 |
+
from pathlib import Path
|
14 |
+
import sys
|
15 |
+
import shutil
|
16 |
+
from typing import List
|
17 |
+
|
18 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
19 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
20 |
+
|
21 |
+
import pandas as pd
|
22 |
+
import torch
|
23 |
+
from torch.utils.data.dataloader import DataLoader
|
24 |
+
from tqdm import tqdm
|
25 |
+
|
26 |
+
from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
|
27 |
+
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
|
28 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
29 |
+
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
|
30 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
|
31 |
+
|
32 |
+
|
33 |
+
def get_args():
|
34 |
+
parser = argparse.ArgumentParser()
|
35 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
36 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
37 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
38 |
+
|
39 |
+
parser.add_argument("--country", default="en-US", type=str)
|
40 |
+
parser.add_argument("--shared_encoder", default="file_dir/global_model/best.bin", type=str)
|
41 |
+
|
42 |
+
parser.add_argument("--max_epochs", default=100, type=int)
|
43 |
+
parser.add_argument("--batch_size", default=64, type=int)
|
44 |
+
parser.add_argument("--learning_rate", default=1e-3, type=float)
|
45 |
+
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
46 |
+
parser.add_argument("--patience", default=5, type=int)
|
47 |
+
parser.add_argument("--serialization_dir", default="country_models", type=str)
|
48 |
+
parser.add_argument("--seed", default=0, type=int)
|
49 |
+
|
50 |
+
args = parser.parse_args()
|
51 |
+
return args
|
52 |
+
|
53 |
+
|
54 |
+
def logging_config(file_dir: str):
|
55 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
56 |
+
|
57 |
+
logging.basicConfig(format=fmt,
|
58 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
59 |
+
level=logging.DEBUG)
|
60 |
+
file_handler = TimedRotatingFileHandler(
|
61 |
+
filename=os.path.join(file_dir, "main.log"),
|
62 |
+
encoding="utf-8",
|
63 |
+
when="D",
|
64 |
+
interval=1,
|
65 |
+
backupCount=7
|
66 |
+
)
|
67 |
+
file_handler.setLevel(logging.INFO)
|
68 |
+
file_handler.setFormatter(logging.Formatter(fmt))
|
69 |
+
logger = logging.getLogger(__name__)
|
70 |
+
logger.addHandler(file_handler)
|
71 |
+
|
72 |
+
return logger
|
73 |
+
|
74 |
+
|
75 |
+
class CollateFunction(object):
|
76 |
+
def __init__(self):
|
77 |
+
pass
|
78 |
+
|
79 |
+
def __call__(self, batch: List[dict]):
|
80 |
+
array_list = list()
|
81 |
+
label_list = list()
|
82 |
+
for sample in batch:
|
83 |
+
array = sample['waveform']
|
84 |
+
label = sample['label']
|
85 |
+
|
86 |
+
array_list.append(array)
|
87 |
+
label_list.append(label)
|
88 |
+
|
89 |
+
array_list = torch.stack(array_list)
|
90 |
+
label_list = torch.stack(label_list)
|
91 |
+
return array_list, label_list
|
92 |
+
|
93 |
+
|
94 |
+
collate_fn = CollateFunction()
|
95 |
+
|
96 |
+
|
97 |
+
def main():
|
98 |
+
args = get_args()
|
99 |
+
|
100 |
+
serialization_dir = Path(args.serialization_dir)
|
101 |
+
serialization_dir.mkdir(parents=True, exist_ok=True)
|
102 |
+
|
103 |
+
logger = logging_config(args.serialization_dir)
|
104 |
+
|
105 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
106 |
+
n_gpu = torch.cuda.device_count()
|
107 |
+
logger.info("GPU available: {}; device: {}".format(n_gpu, device))
|
108 |
+
|
109 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
110 |
+
|
111 |
+
# datasets
|
112 |
+
logger.info("prepare datasets")
|
113 |
+
train_dataset = WaveClassifierExcelDataset(
|
114 |
+
vocab=vocabulary,
|
115 |
+
excel_file=args.train_dataset,
|
116 |
+
category=args.country,
|
117 |
+
category_field="category",
|
118 |
+
label_field="country_labels",
|
119 |
+
expected_sample_rate=8000,
|
120 |
+
max_wave_value=32768.0,
|
121 |
+
)
|
122 |
+
valid_dataset = WaveClassifierExcelDataset(
|
123 |
+
vocab=vocabulary,
|
124 |
+
excel_file=args.valid_dataset,
|
125 |
+
category=args.country,
|
126 |
+
category_field="category",
|
127 |
+
label_field="country_labels",
|
128 |
+
expected_sample_rate=8000,
|
129 |
+
max_wave_value=32768.0,
|
130 |
+
)
|
131 |
+
|
132 |
+
train_data_loader = DataLoader(
|
133 |
+
dataset=train_dataset,
|
134 |
+
batch_size=args.batch_size,
|
135 |
+
shuffle=True,
|
136 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
137 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
|
138 |
+
collate_fn=collate_fn,
|
139 |
+
pin_memory=False,
|
140 |
+
# prefetch_factor=64,
|
141 |
+
)
|
142 |
+
valid_data_loader = DataLoader(
|
143 |
+
dataset=valid_dataset,
|
144 |
+
batch_size=args.batch_size,
|
145 |
+
shuffle=True,
|
146 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
147 |
+
num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
|
148 |
+
collate_fn=collate_fn,
|
149 |
+
pin_memory=False,
|
150 |
+
# prefetch_factor=64,
|
151 |
+
)
|
152 |
+
|
153 |
+
# models - classifier
|
154 |
+
wave_encoder = WaveEncoder(
|
155 |
+
conv1d_block_param_list=[
|
156 |
+
{
|
157 |
+
'batch_norm': True,
|
158 |
+
'in_channels': 80,
|
159 |
+
'out_channels': 16,
|
160 |
+
'kernel_size': 3,
|
161 |
+
'stride': 3,
|
162 |
+
# 'padding': 'same',
|
163 |
+
'activation': 'relu',
|
164 |
+
'dropout': 0.1,
|
165 |
+
},
|
166 |
+
{
|
167 |
+
# 'batch_norm': True,
|
168 |
+
'in_channels': 16,
|
169 |
+
'out_channels': 16,
|
170 |
+
'kernel_size': 3,
|
171 |
+
'stride': 3,
|
172 |
+
# 'padding': 'same',
|
173 |
+
'activation': 'relu',
|
174 |
+
'dropout': 0.1,
|
175 |
+
},
|
176 |
+
{
|
177 |
+
# 'batch_norm': True,
|
178 |
+
'in_channels': 16,
|
179 |
+
'out_channels': 16,
|
180 |
+
'kernel_size': 3,
|
181 |
+
'stride': 3,
|
182 |
+
# 'padding': 'same',
|
183 |
+
'activation': 'relu',
|
184 |
+
'dropout': 0.1,
|
185 |
+
},
|
186 |
+
],
|
187 |
+
mel_spectrogram_param={
|
188 |
+
"sample_rate": 8000,
|
189 |
+
"n_fft": 512,
|
190 |
+
"win_length": 200,
|
191 |
+
"hop_length": 80,
|
192 |
+
"f_min": 10,
|
193 |
+
"f_max": 3800,
|
194 |
+
"window_fn": "hamming",
|
195 |
+
"n_mels": 80,
|
196 |
+
}
|
197 |
+
)
|
198 |
+
|
199 |
+
with open(args.shared_encoder, "rb") as f:
|
200 |
+
state_dict = torch.load(f, map_location=device)
|
201 |
+
processed_state_dict = dict()
|
202 |
+
prefix = "wave_encoder."
|
203 |
+
for k, v in state_dict.items():
|
204 |
+
if not str(k).startswith(prefix):
|
205 |
+
continue
|
206 |
+
k = k[len(prefix):]
|
207 |
+
processed_state_dict[k] = v
|
208 |
+
|
209 |
+
wave_encoder.load_state_dict(
|
210 |
+
state_dict=processed_state_dict,
|
211 |
+
strict=True,
|
212 |
+
)
|
213 |
+
cls_head = ClsHead(
|
214 |
+
input_dim=16,
|
215 |
+
num_layers=2,
|
216 |
+
hidden_dims=[32, 16],
|
217 |
+
activations="relu",
|
218 |
+
dropout=0.1,
|
219 |
+
num_labels=vocabulary.get_vocab_size(namespace="global_labels")
|
220 |
+
)
|
221 |
+
model = WaveClassifier(
|
222 |
+
wave_encoder=wave_encoder,
|
223 |
+
cls_head=cls_head,
|
224 |
+
)
|
225 |
+
model.wave_encoder.requires_grad_(requires_grad=False)
|
226 |
+
model.cls_head.requires_grad_(requires_grad=True)
|
227 |
+
model.to(device)
|
228 |
+
|
229 |
+
# optimizer
|
230 |
+
logger.info("prepare optimizer")
|
231 |
+
optimizer = torch.optim.Adam(
|
232 |
+
model.cls_head.parameters(),
|
233 |
+
lr=args.learning_rate,
|
234 |
+
)
|
235 |
+
lr_scheduler = torch.optim.lr_scheduler.StepLR(
|
236 |
+
optimizer,
|
237 |
+
step_size=2000
|
238 |
+
)
|
239 |
+
focal_loss = FocalLoss(
|
240 |
+
num_classes=vocabulary.get_vocab_size(namespace=args.country),
|
241 |
+
reduction="mean",
|
242 |
+
)
|
243 |
+
categorical_accuracy = CategoricalAccuracy()
|
244 |
+
|
245 |
+
# training loop
|
246 |
+
best_idx_epoch: int = None
|
247 |
+
best_accuracy: float = None
|
248 |
+
patience_count = 0
|
249 |
+
global_step = 0
|
250 |
+
model_filename_list = list()
|
251 |
+
for idx_epoch in range(args.max_epochs):
|
252 |
+
|
253 |
+
# training
|
254 |
+
model.train()
|
255 |
+
total_loss = 0
|
256 |
+
total_examples = 0
|
257 |
+
for step, batch in enumerate(tqdm(train_data_loader, desc="Epoch={} (training)".format(idx_epoch))):
|
258 |
+
input_ids, label_ids = batch
|
259 |
+
input_ids = input_ids.to(device)
|
260 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
261 |
+
|
262 |
+
logits = model.forward(input_ids)
|
263 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
264 |
+
categorical_accuracy(logits, label_ids)
|
265 |
+
|
266 |
+
total_loss += loss.item()
|
267 |
+
total_examples += input_ids.size(0)
|
268 |
+
|
269 |
+
optimizer.zero_grad()
|
270 |
+
loss.backward()
|
271 |
+
optimizer.step()
|
272 |
+
lr_scheduler.step()
|
273 |
+
|
274 |
+
global_step += 1
|
275 |
+
training_loss = total_loss / total_examples
|
276 |
+
training_loss = round(training_loss, 4)
|
277 |
+
training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
278 |
+
training_accuracy = round(training_accuracy, 4)
|
279 |
+
logger.info("Epoch: {}; training_loss: {}; training_accuracy: {}".format(
|
280 |
+
idx_epoch, training_loss, training_accuracy
|
281 |
+
))
|
282 |
+
|
283 |
+
# evaluation
|
284 |
+
model.eval()
|
285 |
+
total_loss = 0
|
286 |
+
total_examples = 0
|
287 |
+
for step, batch in enumerate(tqdm(valid_data_loader, desc="Epoch={} (evaluation)".format(idx_epoch))):
|
288 |
+
input_ids, label_ids = batch
|
289 |
+
input_ids = input_ids.to(device)
|
290 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
291 |
+
|
292 |
+
with torch.no_grad():
|
293 |
+
logits = model.forward(input_ids)
|
294 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
295 |
+
categorical_accuracy(logits, label_ids)
|
296 |
+
|
297 |
+
total_loss += loss.item()
|
298 |
+
total_examples += input_ids.size(0)
|
299 |
+
|
300 |
+
evaluation_loss = total_loss / total_examples
|
301 |
+
evaluation_loss = round(evaluation_loss, 4)
|
302 |
+
evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
303 |
+
evaluation_accuracy = round(evaluation_accuracy, 4)
|
304 |
+
logger.info("Epoch: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
|
305 |
+
idx_epoch, evaluation_loss, evaluation_accuracy
|
306 |
+
))
|
307 |
+
|
308 |
+
# save metric
|
309 |
+
metrics = {
|
310 |
+
"training_loss": training_loss,
|
311 |
+
"training_accuracy": training_accuracy,
|
312 |
+
"evaluation_loss": evaluation_loss,
|
313 |
+
"evaluation_accuracy": evaluation_accuracy,
|
314 |
+
"best_idx_epoch": best_idx_epoch,
|
315 |
+
"best_accuracy": best_accuracy,
|
316 |
+
}
|
317 |
+
metrics_filename = os.path.join(args.serialization_dir, "metrics_epoch_{}.json".format(idx_epoch))
|
318 |
+
with open(metrics_filename, "w", encoding="utf-8") as f:
|
319 |
+
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
320 |
+
|
321 |
+
# save model
|
322 |
+
model_filename = os.path.join(args.serialization_dir, "model_epoch_{}.bin".format(idx_epoch))
|
323 |
+
model_filename_list.append(model_filename)
|
324 |
+
if len(model_filename_list) >= args.num_serialized_models_to_keep:
|
325 |
+
model_filename_to_delete = model_filename_list.pop(0)
|
326 |
+
os.remove(model_filename_to_delete)
|
327 |
+
torch.save(model.state_dict(), model_filename)
|
328 |
+
|
329 |
+
# early stop
|
330 |
+
best_model_filename = os.path.join(args.serialization_dir, "best.bin")
|
331 |
+
if best_accuracy is None:
|
332 |
+
best_idx_epoch = idx_epoch
|
333 |
+
best_accuracy = evaluation_accuracy
|
334 |
+
torch.save(model.state_dict(), best_model_filename)
|
335 |
+
elif evaluation_accuracy > best_accuracy:
|
336 |
+
best_idx_epoch = idx_epoch
|
337 |
+
best_accuracy = evaluation_accuracy
|
338 |
+
torch.save(model.state_dict(), best_model_filename)
|
339 |
+
patience_count = 0
|
340 |
+
elif patience_count >= args.patience:
|
341 |
+
break
|
342 |
+
else:
|
343 |
+
patience_count += 1
|
344 |
+
|
345 |
+
return
|
346 |
+
|
347 |
+
|
348 |
+
if __name__ == "__main__":
|
349 |
+
main()
|
examples/vm_sound_classification8/step_5_train_union.py
ADDED
@@ -0,0 +1,499 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from collections import defaultdict
|
5 |
+
import json
|
6 |
+
import logging
|
7 |
+
from logging.handlers import TimedRotatingFileHandler
|
8 |
+
import os
|
9 |
+
import platform
|
10 |
+
from pathlib import Path
|
11 |
+
import sys
|
12 |
+
import shutil
|
13 |
+
from typing import List
|
14 |
+
|
15 |
+
pwd = os.path.abspath(os.path.dirname(__file__))
|
16 |
+
sys.path.append(os.path.join(pwd, "../../"))
|
17 |
+
|
18 |
+
import pandas as pd
|
19 |
+
import torch
|
20 |
+
from torch.utils.data.dataloader import DataLoader
|
21 |
+
from tqdm import tqdm
|
22 |
+
|
23 |
+
from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
|
24 |
+
from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
|
25 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
26 |
+
from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
|
27 |
+
from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
|
28 |
+
|
29 |
+
|
30 |
+
def get_args():
|
31 |
+
parser = argparse.ArgumentParser()
|
32 |
+
parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
|
33 |
+
|
34 |
+
parser.add_argument("--train_dataset", default="train.xlsx", type=str)
|
35 |
+
parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
|
36 |
+
|
37 |
+
parser.add_argument("--max_steps", default=100000, type=int)
|
38 |
+
parser.add_argument("--save_steps", default=30, type=int)
|
39 |
+
|
40 |
+
parser.add_argument("--batch_size", default=1, type=int)
|
41 |
+
parser.add_argument("--learning_rate", default=1e-3, type=float)
|
42 |
+
parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
|
43 |
+
parser.add_argument("--patience", default=5, type=int)
|
44 |
+
parser.add_argument("--serialization_dir", default="union", type=str)
|
45 |
+
parser.add_argument("--seed", default=0, type=int)
|
46 |
+
|
47 |
+
parser.add_argument("--num_workers", default=0, type=int)
|
48 |
+
|
49 |
+
args = parser.parse_args()
|
50 |
+
return args
|
51 |
+
|
52 |
+
|
53 |
+
def logging_config(file_dir: str):
|
54 |
+
fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
|
55 |
+
|
56 |
+
logging.basicConfig(format=fmt,
|
57 |
+
datefmt="%m/%d/%Y %H:%M:%S",
|
58 |
+
level=logging.DEBUG)
|
59 |
+
file_handler = TimedRotatingFileHandler(
|
60 |
+
filename=os.path.join(file_dir, "main.log"),
|
61 |
+
encoding="utf-8",
|
62 |
+
when="D",
|
63 |
+
interval=1,
|
64 |
+
backupCount=7
|
65 |
+
)
|
66 |
+
file_handler.setLevel(logging.INFO)
|
67 |
+
file_handler.setFormatter(logging.Formatter(fmt))
|
68 |
+
logger = logging.getLogger(__name__)
|
69 |
+
logger.addHandler(file_handler)
|
70 |
+
|
71 |
+
return logger
|
72 |
+
|
73 |
+
|
74 |
+
class CollateFunction(object):
|
75 |
+
def __init__(self):
|
76 |
+
pass
|
77 |
+
|
78 |
+
def __call__(self, batch: List[dict]):
|
79 |
+
array_list = list()
|
80 |
+
label_list = list()
|
81 |
+
for sample in batch:
|
82 |
+
array = sample['waveform']
|
83 |
+
label = sample['label']
|
84 |
+
|
85 |
+
array_list.append(array)
|
86 |
+
label_list.append(label)
|
87 |
+
|
88 |
+
array_list = torch.stack(array_list)
|
89 |
+
label_list = torch.stack(label_list)
|
90 |
+
return array_list, label_list
|
91 |
+
|
92 |
+
|
93 |
+
collate_fn = CollateFunction()
|
94 |
+
|
95 |
+
|
96 |
+
class DatasetIterator(object):
|
97 |
+
def __init__(self, data_loader: DataLoader):
|
98 |
+
self.data_loader = data_loader
|
99 |
+
self.data_loader_iter = iter(self.data_loader)
|
100 |
+
|
101 |
+
def next(self):
|
102 |
+
try:
|
103 |
+
result = self.data_loader_iter.__next__()
|
104 |
+
except StopIteration:
|
105 |
+
self.data_loader_iter = iter(self.data_loader)
|
106 |
+
result = self.data_loader_iter.__next__()
|
107 |
+
return result
|
108 |
+
|
109 |
+
|
110 |
+
def main():
|
111 |
+
args = get_args()
|
112 |
+
|
113 |
+
serialization_dir = Path(args.serialization_dir)
|
114 |
+
serialization_dir.mkdir(parents=True, exist_ok=True)
|
115 |
+
|
116 |
+
logger = logging_config(args.serialization_dir)
|
117 |
+
|
118 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
119 |
+
n_gpu = torch.cuda.device_count()
|
120 |
+
logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
|
121 |
+
|
122 |
+
vocabulary = Vocabulary.from_files(args.vocabulary_dir)
|
123 |
+
namespaces = vocabulary._token_to_index.keys()
|
124 |
+
|
125 |
+
# namespace_to_ratio
|
126 |
+
max_radio = (len(namespaces) - 1) * 3
|
127 |
+
namespace_to_ratio = {n: 1 for n in namespaces}
|
128 |
+
namespace_to_ratio["global_labels"] = max_radio
|
129 |
+
|
130 |
+
# datasets
|
131 |
+
logger.info("prepare datasets")
|
132 |
+
namespace_to_datasets = dict()
|
133 |
+
for namespace in namespaces:
|
134 |
+
logger.info("prepare datasets - {}".format(namespace))
|
135 |
+
if namespace == "global_labels":
|
136 |
+
train_dataset = WaveClassifierExcelDataset(
|
137 |
+
vocab=vocabulary,
|
138 |
+
excel_file=args.train_dataset,
|
139 |
+
category=None,
|
140 |
+
category_field="category",
|
141 |
+
label_field="global_labels",
|
142 |
+
expected_sample_rate=8000,
|
143 |
+
max_wave_value=32768.0,
|
144 |
+
)
|
145 |
+
valid_dataset = WaveClassifierExcelDataset(
|
146 |
+
vocab=vocabulary,
|
147 |
+
excel_file=args.valid_dataset,
|
148 |
+
category=None,
|
149 |
+
category_field="category",
|
150 |
+
label_field="global_labels",
|
151 |
+
expected_sample_rate=8000,
|
152 |
+
max_wave_value=32768.0,
|
153 |
+
)
|
154 |
+
else:
|
155 |
+
train_dataset = WaveClassifierExcelDataset(
|
156 |
+
vocab=vocabulary,
|
157 |
+
excel_file=args.train_dataset,
|
158 |
+
category=namespace,
|
159 |
+
category_field="category",
|
160 |
+
label_field="country_labels",
|
161 |
+
expected_sample_rate=8000,
|
162 |
+
max_wave_value=32768.0,
|
163 |
+
)
|
164 |
+
valid_dataset = WaveClassifierExcelDataset(
|
165 |
+
vocab=vocabulary,
|
166 |
+
excel_file=args.valid_dataset,
|
167 |
+
category=namespace,
|
168 |
+
category_field="category",
|
169 |
+
label_field="country_labels",
|
170 |
+
expected_sample_rate=8000,
|
171 |
+
max_wave_value=32768.0,
|
172 |
+
)
|
173 |
+
|
174 |
+
train_data_loader = DataLoader(
|
175 |
+
dataset=train_dataset,
|
176 |
+
batch_size=args.batch_size,
|
177 |
+
shuffle=True,
|
178 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
179 |
+
# num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
180 |
+
num_workers=args.num_workers,
|
181 |
+
collate_fn=collate_fn,
|
182 |
+
pin_memory=False,
|
183 |
+
# prefetch_factor=64,
|
184 |
+
)
|
185 |
+
valid_data_loader = DataLoader(
|
186 |
+
dataset=valid_dataset,
|
187 |
+
batch_size=args.batch_size,
|
188 |
+
shuffle=True,
|
189 |
+
# Linux 系统中可以使用多个子进程加载数据, 而在 Windows 系统中不能.
|
190 |
+
# num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
|
191 |
+
num_workers=args.num_workers,
|
192 |
+
collate_fn=collate_fn,
|
193 |
+
pin_memory=False,
|
194 |
+
# prefetch_factor=64,
|
195 |
+
)
|
196 |
+
|
197 |
+
namespace_to_datasets[namespace] = {
|
198 |
+
"train_data_loader": train_data_loader,
|
199 |
+
"valid_data_loader": valid_data_loader,
|
200 |
+
}
|
201 |
+
|
202 |
+
# datasets iterator
|
203 |
+
logger.info("prepare datasets iterator")
|
204 |
+
namespace_to_datasets_iter = dict()
|
205 |
+
for namespace in namespaces:
|
206 |
+
logger.info("prepare datasets iterator - {}".format(namespace))
|
207 |
+
train_data_loader = namespace_to_datasets[namespace]["train_data_loader"]
|
208 |
+
valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
|
209 |
+
namespace_to_datasets_iter[namespace] = {
|
210 |
+
"train_data_loader_iter": DatasetIterator(train_data_loader),
|
211 |
+
"valid_data_loader_iter": DatasetIterator(valid_data_loader),
|
212 |
+
}
|
213 |
+
|
214 |
+
# models - encoder
|
215 |
+
logger.info("prepare models - encoder")
|
216 |
+
wave_encoder = WaveEncoder(
|
217 |
+
conv2d_block_param_list=[
|
218 |
+
{
|
219 |
+
"batch_norm": True,
|
220 |
+
"in_channels": 1,
|
221 |
+
"out_channels": 4,
|
222 |
+
"kernel_size": 3,
|
223 |
+
"stride": 1,
|
224 |
+
# "padding": "same",
|
225 |
+
"dilation": 3,
|
226 |
+
"activation": "relu",
|
227 |
+
"dropout": 0.1,
|
228 |
+
},
|
229 |
+
{
|
230 |
+
# "batch_norm": True,
|
231 |
+
"in_channels": 4,
|
232 |
+
"out_channels": 4,
|
233 |
+
"kernel_size": 5,
|
234 |
+
"stride": 2,
|
235 |
+
# "padding": "same",
|
236 |
+
"dilation": 3,
|
237 |
+
"activation": "relu",
|
238 |
+
"dropout": 0.1,
|
239 |
+
},
|
240 |
+
{
|
241 |
+
# "batch_norm": True,
|
242 |
+
"in_channels": 4,
|
243 |
+
"out_channels": 4,
|
244 |
+
"kernel_size": 3,
|
245 |
+
"stride": 1,
|
246 |
+
# "padding": "same",
|
247 |
+
"dilation": 2,
|
248 |
+
"activation": "relu",
|
249 |
+
"dropout": 0.1,
|
250 |
+
},
|
251 |
+
],
|
252 |
+
mel_spectrogram_param={
|
253 |
+
'sample_rate': 8000,
|
254 |
+
'n_fft': 512,
|
255 |
+
'win_length': 200,
|
256 |
+
'hop_length': 80,
|
257 |
+
'f_min': 10,
|
258 |
+
'f_max': 3800,
|
259 |
+
'window_fn': 'hamming',
|
260 |
+
'n_mels': 80,
|
261 |
+
}
|
262 |
+
)
|
263 |
+
|
264 |
+
# models - cls_head
|
265 |
+
logger.info("prepare models - cls_head")
|
266 |
+
namespace_to_cls_heads = dict()
|
267 |
+
for namespace in namespaces:
|
268 |
+
logger.info("prepare models - cls_head - {}".format(namespace))
|
269 |
+
cls_head = ClsHead(
|
270 |
+
input_dim=352,
|
271 |
+
num_layers=2,
|
272 |
+
hidden_dims=[128, 32],
|
273 |
+
activations="relu",
|
274 |
+
dropout=0.1,
|
275 |
+
num_labels=vocabulary.get_vocab_size(namespace=namespace)
|
276 |
+
)
|
277 |
+
namespace_to_cls_heads[namespace] = cls_head
|
278 |
+
|
279 |
+
# models - classifier
|
280 |
+
logger.info("prepare models - classifier")
|
281 |
+
namespace_to_classifier = dict()
|
282 |
+
for namespace in namespaces:
|
283 |
+
logger.info("prepare models - classifier - {}".format(namespace))
|
284 |
+
cls_head = namespace_to_cls_heads[namespace]
|
285 |
+
wave_classifier = WaveClassifier(
|
286 |
+
wave_encoder=wave_encoder,
|
287 |
+
cls_head=cls_head,
|
288 |
+
)
|
289 |
+
wave_classifier.to(device)
|
290 |
+
namespace_to_classifier[namespace] = wave_classifier
|
291 |
+
|
292 |
+
# optimizer
|
293 |
+
logger.info("prepare optimizer")
|
294 |
+
param_optimizer = list()
|
295 |
+
param_optimizer.extend(wave_encoder.parameters())
|
296 |
+
for _, cls_head in namespace_to_cls_heads.items():
|
297 |
+
param_optimizer.extend(cls_head.parameters())
|
298 |
+
|
299 |
+
optimizer = torch.optim.Adam(
|
300 |
+
param_optimizer,
|
301 |
+
lr=args.learning_rate,
|
302 |
+
)
|
303 |
+
lr_scheduler = torch.optim.lr_scheduler.StepLR(
|
304 |
+
optimizer,
|
305 |
+
step_size=10000
|
306 |
+
)
|
307 |
+
focal_loss = FocalLoss(
|
308 |
+
num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
|
309 |
+
reduction="mean",
|
310 |
+
)
|
311 |
+
|
312 |
+
# categorical_accuracy
|
313 |
+
logger.info("prepare categorical_accuracy")
|
314 |
+
namespace_to_categorical_accuracy = dict()
|
315 |
+
for namespace in namespaces:
|
316 |
+
categorical_accuracy = CategoricalAccuracy()
|
317 |
+
namespace_to_categorical_accuracy[namespace] = categorical_accuracy
|
318 |
+
|
319 |
+
# training loop
|
320 |
+
logger.info("prepare training loop")
|
321 |
+
|
322 |
+
model_list = list()
|
323 |
+
best_idx_step = None
|
324 |
+
best_accuracy = None
|
325 |
+
patience_count = 0
|
326 |
+
|
327 |
+
namespace_to_total_loss = defaultdict(float)
|
328 |
+
namespace_to_total_examples = defaultdict(int)
|
329 |
+
for idx_step in tqdm(range(args.max_steps)):
|
330 |
+
|
331 |
+
# training one step
|
332 |
+
loss: torch.Tensor = None
|
333 |
+
for namespace in namespaces:
|
334 |
+
train_data_loader_iter = namespace_to_datasets_iter[namespace]["train_data_loader_iter"]
|
335 |
+
|
336 |
+
ratio = namespace_to_ratio[namespace]
|
337 |
+
model = namespace_to_classifier[namespace]
|
338 |
+
categorical_accuracy = namespace_to_categorical_accuracy[namespace]
|
339 |
+
|
340 |
+
model.train()
|
341 |
+
|
342 |
+
for _ in range(ratio):
|
343 |
+
batch = train_data_loader_iter.next()
|
344 |
+
input_ids, label_ids = batch
|
345 |
+
input_ids = input_ids.to(device)
|
346 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
347 |
+
|
348 |
+
logits = model.forward(input_ids)
|
349 |
+
task_loss = focal_loss.forward(logits, label_ids.view(-1))
|
350 |
+
categorical_accuracy(logits, label_ids)
|
351 |
+
|
352 |
+
if loss is None:
|
353 |
+
loss = task_loss
|
354 |
+
else:
|
355 |
+
loss += task_loss
|
356 |
+
|
357 |
+
namespace_to_total_loss[namespace] += task_loss.item()
|
358 |
+
namespace_to_total_examples[namespace] += input_ids.size(0)
|
359 |
+
|
360 |
+
optimizer.zero_grad()
|
361 |
+
loss.backward()
|
362 |
+
optimizer.step()
|
363 |
+
lr_scheduler.step()
|
364 |
+
|
365 |
+
# logging
|
366 |
+
if (idx_step + 1) % args.save_steps == 0:
|
367 |
+
metrics = dict()
|
368 |
+
|
369 |
+
# training
|
370 |
+
for namespace in namespaces:
|
371 |
+
total_loss = namespace_to_total_loss[namespace]
|
372 |
+
total_examples = namespace_to_total_examples[namespace]
|
373 |
+
|
374 |
+
training_loss = total_loss / total_examples
|
375 |
+
training_loss = round(training_loss, 4)
|
376 |
+
|
377 |
+
categorical_accuracy = namespace_to_categorical_accuracy[namespace]
|
378 |
+
|
379 |
+
training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
380 |
+
training_accuracy = round(training_accuracy, 4)
|
381 |
+
logger.info("Step: {}; namespace: {}; training_loss: {}; training_accuracy: {}".format(
|
382 |
+
idx_step, namespace, training_loss, training_accuracy
|
383 |
+
))
|
384 |
+
metrics[namespace] = {
|
385 |
+
"training_loss": training_loss,
|
386 |
+
"training_accuracy": training_accuracy,
|
387 |
+
}
|
388 |
+
namespace_to_total_loss = defaultdict(float)
|
389 |
+
namespace_to_total_examples = defaultdict(int)
|
390 |
+
|
391 |
+
# evaluation
|
392 |
+
for namespace in namespaces:
|
393 |
+
valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
|
394 |
+
|
395 |
+
model = namespace_to_classifier[namespace]
|
396 |
+
categorical_accuracy = namespace_to_categorical_accuracy[namespace]
|
397 |
+
|
398 |
+
model.eval()
|
399 |
+
|
400 |
+
total_loss = 0
|
401 |
+
total_examples = 0
|
402 |
+
for step, batch in enumerate(valid_data_loader):
|
403 |
+
input_ids, label_ids = batch
|
404 |
+
input_ids = input_ids.to(device)
|
405 |
+
label_ids: torch.LongTensor = label_ids.to(device).long()
|
406 |
+
|
407 |
+
with torch.no_grad():
|
408 |
+
logits = model.forward(input_ids)
|
409 |
+
loss = focal_loss.forward(logits, label_ids.view(-1))
|
410 |
+
categorical_accuracy(logits, label_ids)
|
411 |
+
|
412 |
+
total_loss += loss.item()
|
413 |
+
total_examples += input_ids.size(0)
|
414 |
+
|
415 |
+
evaluation_loss = total_loss / total_examples
|
416 |
+
evaluation_loss = round(evaluation_loss, 4)
|
417 |
+
evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
|
418 |
+
evaluation_accuracy = round(evaluation_accuracy, 4)
|
419 |
+
logger.info("Step: {}; namespace: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
|
420 |
+
idx_step, namespace, evaluation_loss, evaluation_accuracy
|
421 |
+
))
|
422 |
+
metrics[namespace] = {
|
423 |
+
"evaluation_loss": evaluation_loss,
|
424 |
+
"evaluation_accuracy": evaluation_accuracy,
|
425 |
+
}
|
426 |
+
|
427 |
+
# update ratio
|
428 |
+
min_accuracy = min([m["evaluation_accuracy"] for m in metrics.values()])
|
429 |
+
max_accuracy = max([m["evaluation_accuracy"] for m in metrics.values()])
|
430 |
+
width = max_accuracy - min_accuracy
|
431 |
+
for namespace, metric in metrics.items():
|
432 |
+
evaluation_accuracy = metric["evaluation_accuracy"]
|
433 |
+
radio = (max_accuracy - evaluation_accuracy) / width * max_radio
|
434 |
+
radio = int(radio)
|
435 |
+
namespace_to_ratio[namespace] = radio
|
436 |
+
|
437 |
+
msg = "".join(["{}: {}; ".format(k, v) for k, v in namespace_to_ratio.items()])
|
438 |
+
logger.info("namespace to ratio: {}".format(msg))
|
439 |
+
|
440 |
+
# save path
|
441 |
+
step_dir = serialization_dir / "step-{}".format(idx_step)
|
442 |
+
step_dir.mkdir(parents=True, exist_ok=False)
|
443 |
+
|
444 |
+
# save models
|
445 |
+
wave_encoder_filename = step_dir / "wave_encoder.pt"
|
446 |
+
torch.save(wave_encoder.state_dict(), wave_encoder_filename)
|
447 |
+
for namespace in namespaces:
|
448 |
+
cls_head_filename = step_dir / "{}.pt".format(namespace)
|
449 |
+
cls_head = namespace_to_cls_heads[namespace]
|
450 |
+
torch.save(cls_head.state_dict(), cls_head_filename)
|
451 |
+
|
452 |
+
model_list.append(step_dir)
|
453 |
+
if len(model_list) >= args.num_serialized_models_to_keep:
|
454 |
+
model_to_delete: Path = model_list.pop(0)
|
455 |
+
shutil.rmtree(model_to_delete.as_posix())
|
456 |
+
|
457 |
+
# save metric
|
458 |
+
this_accuracy = metrics["global_labels"]["evaluation_accuracy"]
|
459 |
+
if best_accuracy is None:
|
460 |
+
best_idx_step = idx_step
|
461 |
+
best_accuracy = this_accuracy
|
462 |
+
elif metrics["global_labels"]["evaluation_accuracy"] > best_accuracy:
|
463 |
+
best_idx_step = idx_step
|
464 |
+
best_accuracy = this_accuracy
|
465 |
+
else:
|
466 |
+
pass
|
467 |
+
|
468 |
+
metrics_filename = step_dir / "metrics_epoch.json"
|
469 |
+
metrics.update({
|
470 |
+
"idx_step": idx_step,
|
471 |
+
"best_idx_step": best_idx_step,
|
472 |
+
})
|
473 |
+
with open(metrics_filename, "w", encoding="utf-8") as f:
|
474 |
+
json.dump(metrics, f, indent=4, ensure_ascii=False)
|
475 |
+
|
476 |
+
# save best
|
477 |
+
best_dir = serialization_dir / "best"
|
478 |
+
if best_idx_step == idx_step:
|
479 |
+
if best_dir.exists():
|
480 |
+
shutil.rmtree(best_dir)
|
481 |
+
shutil.copytree(step_dir, best_dir)
|
482 |
+
|
483 |
+
# early stop
|
484 |
+
early_stop_flag = False
|
485 |
+
if best_idx_step == idx_step:
|
486 |
+
patience_count = 0
|
487 |
+
else:
|
488 |
+
patience_count += 1
|
489 |
+
if patience_count >= args.patience:
|
490 |
+
early_stop_flag = True
|
491 |
+
|
492 |
+
# early stop
|
493 |
+
if early_stop_flag:
|
494 |
+
break
|
495 |
+
return
|
496 |
+
|
497 |
+
|
498 |
+
if __name__ == "__main__":
|
499 |
+
main()
|
examples/vm_sound_classification8/stop.sh
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
kill -9 `ps -aef | grep 'vm_sound_classification/bin/python3' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
|
install.sh
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# bash install.sh --stage 2 --stop_stage 2 --system_version centos
|
4 |
+
|
5 |
+
|
6 |
+
python_version=3.8.10
|
7 |
+
system_version="centos";
|
8 |
+
|
9 |
+
verbose=true;
|
10 |
+
stage=-1
|
11 |
+
stop_stage=0
|
12 |
+
|
13 |
+
|
14 |
+
# parse options
|
15 |
+
while true; do
|
16 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
17 |
+
case "$1" in
|
18 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
19 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
20 |
+
old_value="(eval echo \\$$name)";
|
21 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
22 |
+
was_bool=true;
|
23 |
+
else
|
24 |
+
was_bool=false;
|
25 |
+
fi
|
26 |
+
|
27 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
28 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
29 |
+
eval "${name}=\"$2\"";
|
30 |
+
|
31 |
+
# Check that Boolean-valued arguments are really Boolean.
|
32 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
33 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
34 |
+
exit 1;
|
35 |
+
fi
|
36 |
+
shift 2;
|
37 |
+
;;
|
38 |
+
|
39 |
+
*) break;
|
40 |
+
esac
|
41 |
+
done
|
42 |
+
|
43 |
+
work_dir="$(pwd)"
|
44 |
+
|
45 |
+
|
46 |
+
if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
|
47 |
+
$verbose && echo "stage 1: install python"
|
48 |
+
cd "${work_dir}" || exit 1;
|
49 |
+
|
50 |
+
sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}"
|
51 |
+
fi
|
52 |
+
|
53 |
+
|
54 |
+
if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
|
55 |
+
$verbose && echo "stage 2: create virtualenv"
|
56 |
+
|
57 |
+
# /usr/local/python-3.6.5/bin/virtualenv vm_sound_classification
|
58 |
+
# source /data/local/bin/vm_sound_classification/bin/activate
|
59 |
+
/usr/local/python-${python_version}/bin/pip3 install virtualenv
|
60 |
+
mkdir -p /data/local/bin
|
61 |
+
cd /data/local/bin || exit 1;
|
62 |
+
/usr/local/python-${python_version}/bin/virtualenv vm_sound_classification
|
63 |
+
|
64 |
+
fi
|
main.py
ADDED
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import argparse
|
4 |
+
from functools import lru_cache
|
5 |
+
from pathlib import Path
|
6 |
+
import platform
|
7 |
+
import shutil
|
8 |
+
import tempfile
|
9 |
+
import zipfile
|
10 |
+
from typing import Tuple
|
11 |
+
|
12 |
+
import gradio as gr
|
13 |
+
from huggingface_hub import snapshot_download
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from project_settings import environment, project_path
|
18 |
+
from toolbox.torch.utils.data.vocabulary import Vocabulary
|
19 |
+
|
20 |
+
|
21 |
+
def get_args():
|
22 |
+
parser = argparse.ArgumentParser()
|
23 |
+
parser.add_argument(
|
24 |
+
"--examples_dir",
|
25 |
+
# default=(project_path / "data").as_posix(),
|
26 |
+
default=(project_path / "data/examples").as_posix(),
|
27 |
+
type=str
|
28 |
+
)
|
29 |
+
parser.add_argument(
|
30 |
+
"--models_repo_id",
|
31 |
+
default="qgyd2021/vm_sound_classification",
|
32 |
+
type=str
|
33 |
+
)
|
34 |
+
parser.add_argument(
|
35 |
+
"--trained_model_dir",
|
36 |
+
default=(project_path / "trained_models").as_posix(),
|
37 |
+
type=str
|
38 |
+
)
|
39 |
+
parser.add_argument(
|
40 |
+
"--hf_token",
|
41 |
+
default=environment.get("hf_token"),
|
42 |
+
type=str,
|
43 |
+
)
|
44 |
+
parser.add_argument(
|
45 |
+
"--server_port",
|
46 |
+
default=environment.get("server_port", 7860),
|
47 |
+
type=int
|
48 |
+
)
|
49 |
+
|
50 |
+
args = parser.parse_args()
|
51 |
+
return args
|
52 |
+
|
53 |
+
|
54 |
+
@lru_cache(maxsize=100)
|
55 |
+
def load_model(model_file: Path):
|
56 |
+
with zipfile.ZipFile(model_file, "r") as f_zip:
|
57 |
+
out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
|
58 |
+
if out_root.exists():
|
59 |
+
shutil.rmtree(out_root.as_posix())
|
60 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
61 |
+
f_zip.extractall(path=out_root)
|
62 |
+
|
63 |
+
tgt_path = out_root / model_file.stem
|
64 |
+
jit_model_file = tgt_path / "trace_model.zip"
|
65 |
+
vocab_path = tgt_path / "vocabulary"
|
66 |
+
|
67 |
+
vocabulary = Vocabulary.from_files(vocab_path.as_posix())
|
68 |
+
|
69 |
+
with open(jit_model_file.as_posix(), "rb") as f:
|
70 |
+
model = torch.jit.load(f)
|
71 |
+
model.eval()
|
72 |
+
|
73 |
+
shutil.rmtree(tgt_path)
|
74 |
+
|
75 |
+
d = {
|
76 |
+
"model": model,
|
77 |
+
"vocabulary": vocabulary
|
78 |
+
}
|
79 |
+
return d
|
80 |
+
|
81 |
+
|
82 |
+
def click_button(audio: np.ndarray,
|
83 |
+
model_name: str,
|
84 |
+
ground_true: str) -> Tuple[str, float]:
|
85 |
+
|
86 |
+
sample_rate, signal = audio
|
87 |
+
|
88 |
+
model_file = "trained_models/{}.zip".format(model_name)
|
89 |
+
model_file = Path(model_file)
|
90 |
+
d = load_model(model_file)
|
91 |
+
|
92 |
+
model = d["model"]
|
93 |
+
vocabulary = d["vocabulary"]
|
94 |
+
|
95 |
+
inputs = signal / (1 << 15)
|
96 |
+
inputs = torch.tensor(inputs, dtype=torch.float32)
|
97 |
+
inputs = torch.unsqueeze(inputs, dim=0)
|
98 |
+
|
99 |
+
with torch.no_grad():
|
100 |
+
logits = model.forward(inputs)
|
101 |
+
probs = torch.nn.functional.softmax(logits, dim=-1)
|
102 |
+
label_idx = torch.argmax(probs, dim=-1)
|
103 |
+
|
104 |
+
label_idx = label_idx.cpu()
|
105 |
+
probs = probs.cpu()
|
106 |
+
|
107 |
+
label_idx = label_idx.numpy()[0]
|
108 |
+
prob = probs.numpy()[0][label_idx]
|
109 |
+
|
110 |
+
label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
|
111 |
+
|
112 |
+
return label_str, round(prob, 4)
|
113 |
+
|
114 |
+
|
115 |
+
def main():
|
116 |
+
args = get_args()
|
117 |
+
|
118 |
+
examples_dir = Path(args.examples_dir)
|
119 |
+
trained_model_dir = Path(args.trained_model_dir)
|
120 |
+
|
121 |
+
# download models
|
122 |
+
if not trained_model_dir.exists():
|
123 |
+
trained_model_dir.mkdir(parents=True, exist_ok=True)
|
124 |
+
_ = snapshot_download(
|
125 |
+
repo_id=args.models_repo_id,
|
126 |
+
local_dir=trained_model_dir.as_posix(),
|
127 |
+
token=args.hf_token,
|
128 |
+
)
|
129 |
+
|
130 |
+
# examples
|
131 |
+
example_zip_file = trained_model_dir / "examples.zip"
|
132 |
+
with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
|
133 |
+
out_root = examples_dir
|
134 |
+
if out_root.exists():
|
135 |
+
shutil.rmtree(out_root.as_posix())
|
136 |
+
out_root.mkdir(parents=True, exist_ok=True)
|
137 |
+
f_zip.extractall(path=out_root)
|
138 |
+
|
139 |
+
# models
|
140 |
+
model_choices = list()
|
141 |
+
for filename in trained_model_dir.glob("*.zip"):
|
142 |
+
model_name = filename.stem
|
143 |
+
if model_name == "examples":
|
144 |
+
continue
|
145 |
+
model_choices.append(model_name)
|
146 |
+
model_choices = list(sorted(model_choices))
|
147 |
+
|
148 |
+
# examples
|
149 |
+
examples = list()
|
150 |
+
for filename in examples_dir.glob("**/*/*.wav"):
|
151 |
+
label = filename.parts[-2]
|
152 |
+
|
153 |
+
examples.append([
|
154 |
+
filename.as_posix(),
|
155 |
+
model_choices[0],
|
156 |
+
label
|
157 |
+
])
|
158 |
+
|
159 |
+
# ui
|
160 |
+
brief_description = """
|
161 |
+
国际语音智能外呼系统, 电话声音分类, 8000, int16.
|
162 |
+
"""
|
163 |
+
|
164 |
+
# ui
|
165 |
+
with gr.Blocks() as blocks:
|
166 |
+
gr.Markdown(value=brief_description)
|
167 |
+
|
168 |
+
with gr.Row():
|
169 |
+
with gr.Column(scale=3):
|
170 |
+
c_audio = gr.Audio(label="audio")
|
171 |
+
with gr.Row():
|
172 |
+
with gr.Column(scale=3):
|
173 |
+
c_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
|
174 |
+
with gr.Column(scale=3):
|
175 |
+
c_ground_true = gr.Textbox(label="ground_true")
|
176 |
+
|
177 |
+
c_button = gr.Button("run", variant="primary")
|
178 |
+
with gr.Column(scale=3):
|
179 |
+
c_label = gr.Textbox(label="label")
|
180 |
+
c_probability = gr.Number(label="probability")
|
181 |
+
|
182 |
+
gr.Examples(
|
183 |
+
examples,
|
184 |
+
inputs=[c_audio, c_model_name, c_ground_true],
|
185 |
+
outputs=[c_label, c_probability],
|
186 |
+
fn=click_button,
|
187 |
+
examples_per_page=5,
|
188 |
+
)
|
189 |
+
|
190 |
+
c_button.click(
|
191 |
+
click_button,
|
192 |
+
inputs=[c_audio, c_model_name, c_ground_true],
|
193 |
+
outputs=[c_label, c_probability],
|
194 |
+
)
|
195 |
+
|
196 |
+
# http://127.0.0.1:7864/
|
197 |
+
blocks.queue().launch(
|
198 |
+
share=False if platform.system() == "Windows" else False,
|
199 |
+
server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
|
200 |
+
server_port=args.server_port
|
201 |
+
)
|
202 |
+
return
|
203 |
+
|
204 |
+
|
205 |
+
if __name__ == "__main__":
|
206 |
+
main()
|
project_settings.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/python3
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
import os
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
from toolbox.os.environment import EnvironmentManager
|
7 |
+
|
8 |
+
|
9 |
+
project_path = os.path.abspath(os.path.dirname(__file__))
|
10 |
+
project_path = Path(project_path)
|
11 |
+
|
12 |
+
environment = EnvironmentManager(
|
13 |
+
path=os.path.join(project_path, "dotenv"),
|
14 |
+
env=os.environ.get("environment", "dev"),
|
15 |
+
)
|
16 |
+
|
17 |
+
|
18 |
+
if __name__ == '__main__':
|
19 |
+
pass
|
requirements.txt
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
torch==2.3.0
|
2 |
+
torchaudio==2.3.0
|
3 |
+
fsspec==2024.5.0
|
4 |
+
librosa==0.10.2
|
5 |
+
pandas==2.0.3
|
6 |
+
openpyxl==3.0.9
|
7 |
+
xlrd==1.2.0
|
8 |
+
tqdm==4.66.4
|
9 |
+
overrides==1.9.0
|
10 |
+
pyyaml==6.0.1
|
11 |
+
evaluate==0.4.2
|
12 |
+
gradio
|
13 |
+
python-dotenv==1.0.1
|
script/install_nvidia_driver.sh
ADDED
@@ -0,0 +1,184 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
#GPU驱动安装需要先将原有的显示关闭, 重启机器, 再进行安装.
|
3 |
+
#参考链接:
|
4 |
+
#https://blog.csdn.net/kingschan/article/details/19033595
|
5 |
+
#https://blog.csdn.net/HaixWang/article/details/90408538
|
6 |
+
#
|
7 |
+
#>>> yum install -y pciutils
|
8 |
+
#查看 linux 机器上是否有 GPU
|
9 |
+
#lspci |grep -i nvidia
|
10 |
+
#
|
11 |
+
#>>> lspci |grep -i nvidia
|
12 |
+
#00:08.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)
|
13 |
+
#
|
14 |
+
#
|
15 |
+
#NVIDIA 驱动程序下载
|
16 |
+
#先在 pytorch 上查看应该用什么 cuda 版本, 再安装对应的 cuda-toolkit cuda.
|
17 |
+
#再根据 gpu 版本下载安装对应的 nvidia 驱动
|
18 |
+
#
|
19 |
+
## pytorch 版本
|
20 |
+
#https://pytorch.org/get-started/locally/
|
21 |
+
#
|
22 |
+
## CUDA 下载 (好像不需要这个)
|
23 |
+
#https://developer.nvidia.com/cuda-toolkit-archive
|
24 |
+
#
|
25 |
+
## nvidia 驱动
|
26 |
+
#https://www.nvidia.cn/Download/index.aspx?lang=cn
|
27 |
+
#http://www.nvidia.com/Download/index.aspx
|
28 |
+
#
|
29 |
+
#在下方的下拉列表中进行选择,针对您的 NVIDIA 产品确定合适的驱动。
|
30 |
+
#产品类型:
|
31 |
+
#Data Center / Tesla
|
32 |
+
#产品系列:
|
33 |
+
#T-Series
|
34 |
+
#产品家族:
|
35 |
+
#Tesla T4
|
36 |
+
#操作系统:
|
37 |
+
#Linux 64-bit
|
38 |
+
#CUDA Toolkit:
|
39 |
+
#10.2
|
40 |
+
#语言:
|
41 |
+
#Chinese (Simpleified)
|
42 |
+
#
|
43 |
+
#
|
44 |
+
#>>> mkdir -p /data/tianxing
|
45 |
+
#>>> cd /data/tianxing
|
46 |
+
#>>> wget https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
|
47 |
+
#>>> sh NVIDIA-Linux-x86_64-440.118.02.run
|
48 |
+
#
|
49 |
+
## 异常:
|
50 |
+
#ERROR: The Nouveau kernel driver is currently in use by your system. This driver is incompatible with the NVIDIA driver, and must be disabled before proceeding. Please consult the NVIDIA driver README and your
|
51 |
+
#Linux distribution's documentation for details on how to correctly disable the Nouveau kernel driver.
|
52 |
+
#[OK]
|
53 |
+
#
|
54 |
+
#For some distributions, Nouveau can be disabled by adding a file in the modprobe configuration directory. Would you like nvidia-installer to attempt to create this modprobe file for you?
|
55 |
+
#[NO]
|
56 |
+
#
|
57 |
+
#ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
|
58 |
+
#page at www.nvidia.com.
|
59 |
+
#[OK]
|
60 |
+
#
|
61 |
+
## 参考链接:
|
62 |
+
#https://blog.csdn.net/kingschan/article/details/19033595
|
63 |
+
#
|
64 |
+
## 禁用原有的显卡驱动 nouveau
|
65 |
+
#>>> echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
|
66 |
+
#>>> sudo dracut --force
|
67 |
+
## 重启
|
68 |
+
#>>> reboot
|
69 |
+
#
|
70 |
+
#>>> init 3
|
71 |
+
#>>> sh NVIDIA-Linux-x86_64-440.118.02.run
|
72 |
+
#
|
73 |
+
## 异常
|
74 |
+
#ERROR: Unable to find the kernel source tree for the currently running kernel. Please make sure you have installed the kernel source files for your kernel and that they are properly configured; on Red Hat Linux systems, for example, be sure you have the 'kernel-source' or 'kernel-devel' RPM installed. If you know the correct kernel source files are installed, you may specify the kernel source path with the '--kernel-source-path' command line option.
|
75 |
+
#[OK]
|
76 |
+
#ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
|
77 |
+
#page at www.nvidia.com.
|
78 |
+
#[OK]
|
79 |
+
#
|
80 |
+
## 参考链接
|
81 |
+
## https://blog.csdn.net/HaixWang/article/details/90408538
|
82 |
+
#
|
83 |
+
#>>> uname -r
|
84 |
+
#3.10.0-1160.49.1.el7.x86_64
|
85 |
+
#>>> yum install kernel-devel kernel-headers -y
|
86 |
+
#>>> yum info kernel-devel kernel-headers
|
87 |
+
#>>> yum install -y "kernel-devel-uname-r == $(uname -r)"
|
88 |
+
#>>> yum -y distro-sync
|
89 |
+
#
|
90 |
+
#>>> sh NVIDIA-Linux-x86_64-440.118.02.run
|
91 |
+
#
|
92 |
+
## 安装成功
|
93 |
+
#WARNING: nvidia-installer was forced to guess the X library path '/usr/lib64' and X module path '/usr/lib64/xorg/modules'; these paths were not queryable from the system. If X fails to find the NVIDIA X driver
|
94 |
+
#module, please install the `pkg-config` utility and the X.Org SDK/development package for your distribution and reinstall the driver.
|
95 |
+
#[OK]
|
96 |
+
#Install NVIDIA's 32-bit compatibility libraries?
|
97 |
+
#[YES]
|
98 |
+
#Installation of the kernel module for the NVIDIA Accelerated Graphics Driver for Linux-x86_64 (version 440.118.02) is now complete.
|
99 |
+
#[OK]
|
100 |
+
#
|
101 |
+
#
|
102 |
+
## 查看 GPU 使用情况; watch -n 1 -d nvidia-smi 每1秒刷新一次.
|
103 |
+
#>>> nvidia-smi
|
104 |
+
#Thu Mar 9 12:00:37 2023
|
105 |
+
#+-----------------------------------------------------------------------------+
|
106 |
+
#| NVIDIA-SMI 440.118.02 Driver Version: 440.118.02 CUDA Version: 10.2 |
|
107 |
+
#|-------------------------------+----------------------+----------------------+
|
108 |
+
#| GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
|
109 |
+
#| Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
|
110 |
+
#|===============================+======================+======================|
|
111 |
+
#| 0 Tesla T4 Off | 00000000:00:08.0 Off | Off |
|
112 |
+
#| N/A 54C P0 22W / 70W | 0MiB / 16127MiB | 0% Default |
|
113 |
+
#+-------------------------------+----------------------+----------------------+
|
114 |
+
#
|
115 |
+
#+-----------------------------------------------------------------------------+
|
116 |
+
#| Processes: GPU Memory |
|
117 |
+
#| GPU PID Type Process name Usage |
|
118 |
+
#|=============================================================================|
|
119 |
+
#| No running processes found |
|
120 |
+
#+-----------------------------------------------------------------------------+
|
121 |
+
#
|
122 |
+
#
|
123 |
+
|
124 |
+
# params
|
125 |
+
stage=1
|
126 |
+
nvidia_driver_filename=https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
|
127 |
+
|
128 |
+
# parse options
|
129 |
+
while true; do
|
130 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
131 |
+
case "$1" in
|
132 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
133 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
134 |
+
old_value="(eval echo \\$$name)";
|
135 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
136 |
+
was_bool=true;
|
137 |
+
else
|
138 |
+
was_bool=false;
|
139 |
+
fi
|
140 |
+
|
141 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
142 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
143 |
+
eval "${name}=\"$2\"";
|
144 |
+
|
145 |
+
# Check that Boolean-valued arguments are really Boolean.
|
146 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
147 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
148 |
+
exit 1;
|
149 |
+
fi
|
150 |
+
shift 2;
|
151 |
+
;;
|
152 |
+
|
153 |
+
*) break;
|
154 |
+
esac
|
155 |
+
done
|
156 |
+
|
157 |
+
echo "stage: ${stage}";
|
158 |
+
|
159 |
+
yum -y install wget
|
160 |
+
yum -y install sudo
|
161 |
+
|
162 |
+
if [ ${stage} -eq 0 ]; then
|
163 |
+
mkdir -p /data/dep
|
164 |
+
cd /data/dep || echo 1;
|
165 |
+
wget -P /data/dep ${nvidia_driver_filename}
|
166 |
+
|
167 |
+
echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
|
168 |
+
sudo dracut --force
|
169 |
+
# 重启
|
170 |
+
reboot
|
171 |
+
elif [ ${stage} -eq 1 ]; then
|
172 |
+
init 3
|
173 |
+
|
174 |
+
yum install -y kernel-devel kernel-headers
|
175 |
+
yum info kernel-devel kernel-headers
|
176 |
+
yum install -y "kernel-devel-uname-r == $(uname -r)"
|
177 |
+
yum -y distro-sync
|
178 |
+
|
179 |
+
cd /data/dep || echo 1;
|
180 |
+
|
181 |
+
# 安装时, 需要回车三下.
|
182 |
+
sh NVIDIA-Linux-x86_64-440.118.02.run
|
183 |
+
nvidia-smi
|
184 |
+
fi
|
script/install_python.sh
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env bash
|
2 |
+
|
3 |
+
# 参数:
|
4 |
+
python_version="3.6.5";
|
5 |
+
system_version="centos";
|
6 |
+
|
7 |
+
|
8 |
+
# parse options
|
9 |
+
while true; do
|
10 |
+
[ -z "${1:-}" ] && break; # break if there are no arguments
|
11 |
+
case "$1" in
|
12 |
+
--*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
|
13 |
+
eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
|
14 |
+
old_value="(eval echo \\$$name)";
|
15 |
+
if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
|
16 |
+
was_bool=true;
|
17 |
+
else
|
18 |
+
was_bool=false;
|
19 |
+
fi
|
20 |
+
|
21 |
+
# Set the variable to the right value-- the escaped quotes make it work if
|
22 |
+
# the option had spaces, like --cmd "queue.pl -sync y"
|
23 |
+
eval "${name}=\"$2\"";
|
24 |
+
|
25 |
+
# Check that Boolean-valued arguments are really Boolean.
|
26 |
+
if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
|
27 |
+
echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
|
28 |
+
exit 1;
|
29 |
+
fi
|
30 |
+
shift 2;
|
31 |
+
;;
|
32 |
+
|
33 |
+
*) break;
|
34 |
+
esac
|
35 |
+
done
|
36 |
+
|
37 |
+
echo "python_version: ${python_version}";
|
38 |
+
echo "system_version: ${system_version}";
|
39 |
+
|
40 |
+
|
41 |
+
if [ ${system_version} = "centos" ]; then
|
42 |
+
# 安装 python 开发编译环境
|
43 |
+
yum -y groupinstall "Development tools"
|
44 |
+
yum -y install zlib-devel bzip2-devel openssl-devel ncurses-devel sqlite-devel readline-devel tk-devel gdbm-devel db4-devel libpcap-devel xz-devel
|
45 |
+
yum install libffi-devel -y
|
46 |
+
yum install -y wget
|
47 |
+
yum install -y make
|
48 |
+
|
49 |
+
mkdir -p /data/dep
|
50 |
+
cd /data/dep || exit 1;
|
51 |
+
if [ ! -e Python-${python_version}.tgz ]; then
|
52 |
+
wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
|
53 |
+
fi
|
54 |
+
|
55 |
+
cd /data/dep || exit 1;
|
56 |
+
if [ ! -d Python-${python_version} ]; then
|
57 |
+
tar -zxvf Python-${python_version}.tgz
|
58 |
+
cd /data/dep/Python-${python_version} || exit 1;
|
59 |
+
fi
|
60 |
+
|
61 |
+
mkdir /usr/local/python-${python_version}
|
62 |
+
./configure --prefix=/usr/local/python-${python_version}
|
63 |
+
make && make install
|
64 |
+
|
65 |
+
/usr/local/python-${python_version}/bin/python3 -V
|
66 |
+
/usr/local/python-${python_version}/bin/pip3 -V
|
67 |
+
|
68 |
+
rm -rf /usr/local/bin/python3
|
69 |
+
rm -rf /usr/local/bin/pip3
|
70 |
+
ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
|
71 |
+
ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3
|
72 |
+
|
73 |
+
python3 -V
|
74 |
+
pip3 -V
|
75 |
+
|
76 |
+
elif [ ${system_version} = "ubuntu" ]; then
|
77 |
+
# 安装 python 开发编译环境
|
78 |
+
# https://zhuanlan.zhihu.com/p/506491209
|
79 |
+
|
80 |
+
# 刷新软件包目录
|
81 |
+
sudo apt update
|
82 |
+
# 列出当前可用的更新
|
83 |
+
sudo apt list --upgradable
|
84 |
+
# 如上一步提示有可以更新的项目,则执行更新
|
85 |
+
sudo apt -y upgrade
|
86 |
+
# 安装 GCC 编译器
|
87 |
+
sudo apt install gcc
|
88 |
+
# 检查安装是否成功
|
89 |
+
gcc -v
|
90 |
+
|
91 |
+
# 安装依赖
|
92 |
+
sudo apt install -y build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libreadline-dev libffi-dev libbz2-dev liblzma-dev sqlite3 libsqlite3-dev tk-dev uuid-dev libgdbm-compat-dev
|
93 |
+
|
94 |
+
mkdir -p /data/dep
|
95 |
+
cd /data/dep || exit 1;
|
96 |
+
if [ ! -e Python-${python_version}.tgz ]; then
|
97 |
+
# sudo wget -P /data/dep https://www.python.org/ftp/python/3.6.5/Python-3.6.5.tgz
|
98 |
+
sudo wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
|
99 |
+
fi
|
100 |
+
|
101 |
+
cd /data/dep || exit 1;
|
102 |
+
if [ ! -d Python-${python_version} ]; then
|
103 |
+
# tar -zxvf Python-3.6.5.tgz
|
104 |
+
tar -zxvf Python-${python_version}.tgz
|
105 |
+
# cd /data/dep/Python-3.6.5
|
106 |
+
cd /data/dep/Python-${python_version} || exit 1;
|
107 |
+
fi
|
108 |
+
|
109 |
+
# mkdir /usr/local/python-3.6.5
|
110 |
+
mkdir /usr/local/python-${python_version}
|
111 |
+
|
112 |
+
# 检查依赖与配置编译
|
113 |
+
# sudo ./configure --prefix=/usr/local/python-3.6.5 --enable-optimizations --with-lto --enable-shared
|
114 |
+
sudo ./configure --prefix=/usr/local/python-${python_version} --enable-optimizations --with-lto --enable-shared
|
115 |
+
cpu_count=$(cat /proc/cpuinfo | grep processor | wc -l)
|
116 |
+
# sudo make -j 4
|
117 |
+
sudo make -j "${cpu_count}"
|
118 |
+
|
119 |
+
/usr/local/python-${python_version}/bin/python3 -V
|
120 |
+
/usr/local/python-${python_version}/bin/pip3 -V
|
121 |
+
|
122 |
+
rm -rf /usr/local/bin/python3
|
123 |
+
rm -rf /usr/local/bin/pip3
|
124 |
+
ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
|
125 |
+
ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3
|
126 |
+
|
127 |
+
python3 -V
|
128 |
+
pip3 -V
|
129 |
+
fi
|