HoneyTian committed · Commit bfa885e · 0 Parent(s)

first commit

Files changed (50)
  1. .gitattributes +35 -0
  2. .gitignore +19 -0
  3. Dockerfile +21 -0
  4. README.md +11 -0
  5. examples/sample_filter/bad_case_find.py +84 -0
  6. examples/sample_filter/correction.py +70 -0
  7. examples/sample_filter/find_label_error_wav.py +77 -0
  8. examples/sample_filter/test2.py +78 -0
  9. examples/sample_filter/wav_find_by_task_excel.py +92 -0
  10. examples/vm_sound_classification/requirements.txt +10 -0
  11. examples/vm_sound_classification/run.sh +197 -0
  12. examples/vm_sound_classification/run_batch.sh +268 -0
  13. examples/vm_sound_classification/step_1_prepare_data.py +194 -0
  14. examples/vm_sound_classification/step_2_make_vocabulary.py +51 -0
  15. examples/vm_sound_classification/step_3_train_model.py +367 -0
  16. examples/vm_sound_classification/step_4_evaluation_model.py +128 -0
  17. examples/vm_sound_classification/step_5_export_models.py +106 -0
  18. examples/vm_sound_classification/step_6_infer.py +91 -0
  19. examples/vm_sound_classification/step_7_test_model.py +93 -0
  20. examples/vm_sound_classification/stop.sh +3 -0
  21. examples/vm_sound_classification/yaml/conv2d-classifier-2-ch16.yaml +45 -0
  22. examples/vm_sound_classification/yaml/conv2d-classifier-2-ch32.yaml +45 -0
  23. examples/vm_sound_classification/yaml/conv2d-classifier-2-ch4.yaml +45 -0
  24. examples/vm_sound_classification/yaml/conv2d-classifier-2-ch8.yaml +45 -0
  25. examples/vm_sound_classification/yaml/conv2d-classifier-3-ch16.yaml +45 -0
  26. examples/vm_sound_classification/yaml/conv2d-classifier-3-ch32.yaml +45 -0
  27. examples/vm_sound_classification/yaml/conv2d-classifier-3-ch4.yaml +45 -0
  28. examples/vm_sound_classification/yaml/conv2d-classifier-3-ch8.yaml +45 -0
  29. examples/vm_sound_classification/yaml/conv2d-classifier-4-ch16.yaml +45 -0
  30. examples/vm_sound_classification/yaml/conv2d-classifier-4-ch32.yaml +45 -0
  31. examples/vm_sound_classification/yaml/conv2d-classifier-4-ch4.yaml +45 -0
  32. examples/vm_sound_classification/yaml/conv2d-classifier-4-ch8.yaml +45 -0
  33. examples/vm_sound_classification/yaml/conv2d-classifier-8-ch16.yaml +45 -0
  34. examples/vm_sound_classification/yaml/conv2d-classifier-8-ch32.yaml +45 -0
  35. examples/vm_sound_classification/yaml/conv2d-classifier-8-ch4.yaml +45 -0
  36. examples/vm_sound_classification/yaml/conv2d-classifier-8-ch8.yaml +45 -0
  37. examples/vm_sound_classification8/requirements.txt +9 -0
  38. examples/vm_sound_classification8/run.sh +157 -0
  39. examples/vm_sound_classification8/step_1_prepare_data.py +156 -0
  40. examples/vm_sound_classification8/step_2_make_vocabulary.py +69 -0
  41. examples/vm_sound_classification8/step_3_train_global_model.py +328 -0
  42. examples/vm_sound_classification8/step_4_train_country_model.py +349 -0
  43. examples/vm_sound_classification8/step_5_train_union.py +499 -0
  44. examples/vm_sound_classification8/stop.sh +3 -0
  45. install.sh +64 -0
  46. main.py +206 -0
  47. project_settings.py +19 -0
  48. requirements.txt +13 -0
  49. script/install_nvidia_driver.sh +184 -0
  50. script/install_python.sh +129 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
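
These rules route large binary artifacts (checkpoints, archives, serialized arrays) through Git LFS. As a sketch of how one more such rule would typically be added — assuming the git-lfs extension is installed; these commands are not part of this commit, and "*.flac" is only an illustrative pattern:

    git lfs install
    git lfs track "*.flac"   # appends the matching "filter=lfs diff=lfs merge=lfs -text" line to .gitattributes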
.gitignore ADDED
@@ -0,0 +1,19 @@
+
+ .git/
+ .idea/
+
+ **/file_dir
+ **/flagged/
+ **/log/
+ **/logs/
+ **/__pycache__/
+
+ /data/
+ /docs/
+ /dotenv/
+ /examples/**/*.wav
+ /trained_models/
+ /temp/
+
+ #**/*.wav
+ **/*.xlsx
Dockerfile ADDED
@@ -0,0 +1,21 @@
+ FROM python:3.8
+
+ WORKDIR /code
+
+ COPY . /code
+
+ RUN pip install --upgrade pip
+ RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
+
+ RUN useradd -m -u 1000 user
+
+ USER user
+
+ ENV HOME=/home/user \
+     PATH=/home/user/.local/bin:$PATH
+
+ WORKDIR $HOME/app
+
+ COPY --chown=user . $HOME/app
+
+ CMD ["python3", "main.py"]
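
For local testing, the image defined above could be built and run roughly as follows — a minimal sketch; the image tag and the port mapping are assumptions, since main.py's listening port is not shown in this view (7860 is the usual Gradio/Spaces default):

    docker build -t vm-sound-classification .
    docker run -p 7860:7860 vm-sound-classification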
README.md ADDED
@@ -0,0 +1,11 @@
+ ---
+ title: VM Sound Classification
+ emoji: 🐢
+ colorFrom: purple
+ colorTo: blue
+ sdk: docker
+ pinned: false
+ license: apache-2.0
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
examples/sample_filter/bad_case_find.py ADDED
@@ -0,0 +1,84 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from pathlib import Path
+ import shutil
+
+ from gradio_client import Client, handle_file
+ from tqdm import tqdm
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--data_dir",
+         # default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\data",
+         # default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\us-3",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\transfer",
+         # default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\id",
+         type=str
+     )
+     parser.add_argument(
+         "--keep_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\keep",
+         type=str
+     )
+     parser.add_argument(
+         "--trash_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\trash",
+         type=str
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     data_dir = Path(args.data_dir)
+     keep_dir = Path(args.keep_dir)
+     keep_dir.mkdir(parents=True, exist_ok=True)
+     # trash_dir = Path(args.trash_dir)
+     # trash_dir.mkdir(parents=True, exist_ok=True)
+
+     client = Client("http://127.0.0.1:7864/")
+
+     for idx, filename in tqdm(enumerate(data_dir.glob("**/*.wav"))):
+         # if idx < 400:
+         #     continue
+         filename = filename.as_posix()
+
+         label1, prob1 = client.predict(
+             audio=handle_file(filename),
+             # model_name="vm_sound_classification8-ch32",
+             model_name="voicemail-en-ph-2-ch4",
+             ground_true="Hello!!",
+             api_name="/click_button"
+         )
+         prob1 = float(prob1)
+
+         label2, prob2 = client.predict(
+             audio=handle_file(filename),
+             # model_name="vm_sound_classification8-ch32",
+             model_name="sound-8-ch32",
+             ground_true="Hello!!",
+             api_name="/click_button"
+         )
+         prob2 = float(prob2)
+
+         if label1 == "voicemail" and label2 in ("voicemail", "bell") and prob1 > 0.6:
+             pass
+         elif label1 == "non_voicemail" and label2 not in ("voicemail", "bell") and prob1 > 0.6:
+             pass
+         else:
+             print(f"label1: {label1}, prob1: {prob1}, label2: {label2}, prob2: {prob2}")
+             shutil.move(
+                 filename,
+                 keep_dir.as_posix(),
+             )
+             # exit(0)
+     return
+
+
+ if __name__ == '__main__':
+     main()
examples/sample_filter/correction.py ADDED
@@ -0,0 +1,70 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from pathlib import Path
+ import shutil
+
+ from gradio_client import Client, handle_file
+ from tqdm import tqdm
+
+ from project_settings import project_path
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--data_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\wav_finished\en-PH\wav_finished",
+         type=str
+     )
+     parser.add_argument(
+         "--correction_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\correction",
+         type=str
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     data_dir = Path(args.data_dir)
+     correction_dir = Path(args.correction_dir)
+     correction_dir.mkdir(parents=True, exist_ok=True)
+
+     client = Client("http://127.0.0.1:7864/")
+
+     for idx, filename in tqdm(enumerate(data_dir.glob("**/*.wav"))):
+         # if idx < 200:
+         #     continue
+         ground_truth = filename.parts[-2]
+         filename = filename.as_posix()
+
+         label, prob = client.predict(
+             audio=handle_file(filename),
+             model_name="voicemail-en-ph-2-ch32",
+             ground_true="Hello!!",
+             api_name="/click_button"
+         )
+         prob = float(prob)
+
+         if label == "voicemail" and ground_truth in ("voicemail", "bell"):
+             pass
+         elif label == "non_voicemail" and ground_truth not in ("voicemail", "bell"):
+             pass
+         else:
+             print(f"ground_truth: {ground_truth}, label: {label}, prob: {prob}")
+
+             tgt_dir = correction_dir / ground_truth
+             tgt_dir.mkdir(parents=True, exist_ok=True)
+             shutil.move(
+                 filename,
+                 tgt_dir.as_posix(),
+             )
+
+     return
+
+
+ if __name__ == '__main__':
+     main()
examples/sample_filter/find_label_error_wav.py ADDED
@@ -0,0 +1,77 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from pathlib import Path
+ import shutil
+
+ from gradio_client import Client, handle_file
+ from tqdm import tqdm
+
+ from project_settings import project_path
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--data_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\wav_finished\en-US\wav_finished",
+         type=str
+     )
+     parser.add_argument(
+         "--keep_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\wav_finished\en-US\keep",
+         type=str
+     )
+     parser.add_argument(
+         "--trash_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\wav_finished\en-US\trash",
+         type=str
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     data_dir = Path(args.data_dir)
+     keep_dir = Path(args.keep_dir)
+     keep_dir.mkdir(parents=True, exist_ok=True)
+     trash_dir = Path(args.trash_dir)
+     # trash_dir.mkdir(parents=True, exist_ok=True)
+
+     client = Client("http://127.0.0.1:7864/")
+
+     for idx, filename in tqdm(enumerate(data_dir.glob("**/*.wav"))):
+         # if idx < 200:
+         #     continue
+         ground_truth = filename.parts[-2]
+         filename = filename.as_posix()
+
+         label1, prob1 = client.predict(
+             audio=handle_file(filename),
+             # model_name="vm_sound_classification8-ch32",
+             model_name="voicemail-en-us-2-ch32",
+             ground_true="Hello!!",
+             api_name="/click_button"
+         )
+         prob1 = float(prob1)
+         print(f"label: {label1}, prob: {prob1}, ground_truth: {ground_truth}")
+
+         if label1 == "voicemail" and ground_truth in ("bell", "voicemail") and prob1 > 0.65:
+             pass
+         elif label1 == "non_voicemail" and ground_truth not in ("bell", "voicemail") and prob1 > 0.65:
+             pass
+         else:
+             tgt = keep_dir / ground_truth
+             tgt.mkdir(parents=True, exist_ok=True)
+             shutil.move(
+                 filename,
+                 tgt.as_posix(),
+             )
+
+     return
+
+
+ if __name__ == '__main__':
+     main()
examples/sample_filter/test2.py ADDED
@@ -0,0 +1,78 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from pathlib import Path
+ import shutil
+
+ from gradio_client import Client, handle_file
+ from tqdm import tqdm
+
+ from project_settings import project_path
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--data_dir",
+         # default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\data-1",
+         # default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-2\temp\VoiceAppVoicemailDetection-1",
+         # default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-3\temp\VoiceAppVoicemailDetection-1",
+         # default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-4\temp\VoiceAppVoicemailDetection-1",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\transfer",
+         type=str
+     )
+     parser.add_argument(
+         "--keep_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\keep-3",
+         type=str
+     )
+     parser.add_argument(
+         "--trash_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\trash",
+         type=str
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     data_dir = Path(args.data_dir)
+     keep_dir = Path(args.keep_dir)
+     keep_dir.mkdir(parents=True, exist_ok=True)
+     trash_dir = Path(args.trash_dir)
+     trash_dir.mkdir(parents=True, exist_ok=True)
+
+     client = Client("http://127.0.0.1:7864/")
+
+     for idx, filename in tqdm(enumerate(data_dir.glob("*.wav"))):
+         if idx < 200:
+             continue
+         filename = filename.as_posix()
+
+         label1, prob1 = client.predict(
+             audio=handle_file(filename),
+             # model_name="vm_sound_classification8-ch32",
+             model_name="voicemail-ms-my-2-ch32",
+             ground_true="Hello!!",
+             api_name="/click_button"
+         )
+         prob1 = float(prob1)
+         print(f"label: {label1}, prob: {prob1}")
+
+         if label1 == "voicemail" and prob1 < 0.95:
+             shutil.move(
+                 filename,
+                 keep_dir.as_posix(),
+             )
+         elif label1 != "voicemail" and prob1 < 0.85:
+             shutil.move(
+                 filename,
+                 keep_dir.as_posix(),
+             )
+     return
+
+
+ if __name__ == '__main__':
+     main()
examples/sample_filter/wav_find_by_task_excel.py ADDED
@@ -0,0 +1,92 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import os
+ from pathlib import Path
+ import shutil
+
+ import pandas as pd
+ from gradio_client import Client, handle_file
+ from tqdm import tqdm
+
+ from project_settings import project_path
+
+
+ task_file_str = """
+ task_DcTask_1_PH_LIVE_20250328_20250328-1.xlsx
+ task_DcTask_1_PH_LIVE_20250329_20250329-1.xlsx
+ task_DcTask_1_PH_LIVE_20250331_20250331-1.xlsx
+ task_DcTask_3_PH_LIVE_20250328_20250328-1.xlsx
+ task_DcTask_3_PH_LIVE_20250331_20250331-1.xlsx
+ task_DcTask_9_PH_LIVE_20250329_20250329-1.xlsx
+ task_DcTask_9_PH_LIVE_20250331_20250331-1.xlsx
+ """
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--task_file_str",
+         default=task_file_str,
+         type=str
+     )
+     parser.add_argument(
+         "--wav_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\phl",
+         # default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-2\temp\VoiceAppVoicemailDetection-1",
+         # default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-3\temp\VoiceAppVoicemailDetection-1",
+         # default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\temp-4\temp\VoiceAppVoicemailDetection-1",
+         type=str
+     )
+     parser.add_argument(
+         "--output_dir",
+         default=r"E:\Users\tianx\HuggingDatasets\vm_sound_classification\data\transfer",
+         type=str
+     )
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+     wav_dir = Path(args.wav_dir)
+     output_dir = Path(args.output_dir)
+     output_dir.mkdir(parents=True, exist_ok=True)
+
+     task_file_list = task_file_str.split("\n")
+     task_file_list = [task_file for task_file in task_file_list if len(task_file.strip()) != 0]
+     print(f"task_file_list: {task_file_list}")
+
+     for task_file in task_file_list:
+         df = pd.read_excel(task_file)
+
+         transfer_set = set()
+         for i, row in df.iterrows():
+             # Excel columns: 通话ID = call ID, 意向标签 = intent label.
+             call_id = row["通话ID"]
+             intent_str = row["意向标签"]
+             if intent_str == "Connection - Transferred to agent":
+                 transfer_set.add(call_id)
+             if intent_str == "Connection - No human voice detected":
+                 transfer_set.add(call_id)
+
+         print(f"transfer count: {len(transfer_set)}")
+
+         for idx, filename in tqdm(enumerate(wav_dir.glob("**/*.wav"))):
+
+             basename = filename.stem
+             call_id, _, _, _ = basename.split("_")
+
+             if call_id not in transfer_set:
+                 continue
+
+             print(filename.as_posix())
+             shutil.move(
+                 filename.as_posix(),
+                 output_dir.as_posix()
+             )
+
+     return
+
+
+ if __name__ == '__main__':
+     main()
examples/vm_sound_classification/requirements.txt ADDED
@@ -0,0 +1,10 @@
+ torch==1.13.1
+ torchaudio==0.13.1
+ fsspec==2022.1.0
+ librosa==0.9.2
+ pandas==1.1.5
+ openpyxl==3.0.9
+ xlrd==1.2.0
+ tqdm==4.64.1
+ overrides==1.9.0
+ pyyaml==6.0.1
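
These pins are consistent with the python:3.8 base image in the Dockerfile. Outside Docker, a matching environment could be prepared along these lines — a sketch; the virtualenv location mirrors the interpreter path that run.sh expects on centos/ubuntu:

    python3 -m venv /data/local/bin/vm_sound_classification
    /data/local/bin/vm_sound_classification/bin/pip install -r examples/vm_sound_classification/requirements.txt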
examples/vm_sound_classification/run.sh ADDED
@@ -0,0 +1,197 @@
+ #!/usr/bin/env bash
+
+ : <<'END'
+
+ sh run.sh --stage 0 --stop_stage 1 --system_version windows --file_folder_name file_dir --final_model_name sound-4-ch32 \
+ --filename_patterns "E:/Users/tianx/HuggingDatasets/vm_sound_classification/data/wav_finished/wav_finished/en-US/wav_finished/*/*.wav \
+ E:/Users/tianx/HuggingDatasets/vm_sound_classification/data/wav_finished/id-ID/wav_finished/*/*.wav" \
+ --label_plan 4
+
+ sh run.sh --stage 2 --stop_stage 2 --system_version windows --file_folder_name file_dir --final_model_name sound-2-ch32 \
+ --filename_patterns "E:/Users/tianx/HuggingDatasets/vm_sound_classification/data/wav_finished/wav_finished/en-US/wav_finished/*/*.wav \
+ E:/Users/tianx/HuggingDatasets/vm_sound_classification/data/wav_finished/id-ID/wav_finished/*/*.wav" \
+ --label_plan 4
+
+ sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-3-ch32 \
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ --label_plan 3 \
+ --config_file "yaml/conv2d-classifier-3-ch4.yaml"
+
+ sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ms-my-2-ch32 \
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ms-MY/wav_finished/*/*.wav" \
+ --label_plan 2-voicemail \
+ --config_file "yaml/conv2d-classifier-2-ch32.yaml"
+
+ END
+
+
+ # params
+ system_version="windows";
+ verbose=true;
+ stage=0  # start from 0 if you need to start from data preparation
+ stop_stage=9
+
+ work_dir="$(pwd)"
+ file_folder_name=file_folder_name
+ final_model_name=final_model_name
+ filename_patterns="/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
+ label_plan=4
+ config_file="yaml/conv2d-classifier-2-ch4.yaml"
+ pretrained_model=null
+ nohup_name=nohup.out
+
+ country=en-US
+
+ # model params
+ batch_size=64
+ max_epochs=200
+ save_top_k=10
+ patience=5
+
+
+ # parse options
+ while true; do
+     [ -z "${1:-}" ] && break;  # break if there are no arguments
+     case "$1" in
+         --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+             eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+             old_value="$(eval echo \$$name)";
+             if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+                 was_bool=true;
+             else
+                 was_bool=false;
+             fi
+
+             # Set the variable to the right value -- the escaped quotes make it work if
+             # the option had spaces, like --cmd "queue.pl -sync y"
+             eval "${name}=\"$2\"";
+
+             # Check that Boolean-valued arguments are really Boolean.
+             if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+                 echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+                 exit 1;
+             fi
+             shift 2;
+             ;;
+
+         *) break;
+     esac
+ done
+
+ file_dir="${work_dir}/${file_folder_name}"
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
+
+ dataset="${file_dir}/dataset.xlsx"
+ train_dataset="${file_dir}/train.xlsx"
+ valid_dataset="${file_dir}/valid.xlsx"
+ evaluation_file="${file_dir}/evaluation.xlsx"
+ vocabulary_dir="${file_dir}/vocabulary"
+
+ $verbose && echo "system_version: ${system_version}"
+ $verbose && echo "file_folder_name: ${file_folder_name}"
+
+ if [ $system_version == "windows" ]; then
+     alias python3='D:/Users/tianx/PycharmProjects/virtualenv/vm_sound_classification/Scripts/python.exe'
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
+     #source /data/local/bin/vm_sound_classification/bin/activate
+     alias python3='/data/local/bin/vm_sound_classification/bin/python3'
+ fi
+
+
+ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
+     $verbose && echo "stage 0: prepare data"
+     cd "${work_dir}" || exit 1
+     python3 step_1_prepare_data.py \
+         --file_dir "${file_dir}" \
+         --filename_patterns "${filename_patterns}" \
+         --train_dataset "${train_dataset}" \
+         --valid_dataset "${valid_dataset}" \
+         --label_plan "${label_plan}"
+
+ fi
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+     $verbose && echo "stage 1: make vocabulary"
+     cd "${work_dir}" || exit 1
+     python3 step_2_make_vocabulary.py \
+         --vocabulary_dir "${vocabulary_dir}" \
+         --train_dataset "${train_dataset}" \
+         --valid_dataset "${valid_dataset}"
+
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+     $verbose && echo "stage 2: train model"
+     cd "${work_dir}" || exit 1
+     python3 step_3_train_model.py \
+         --vocabulary_dir "${vocabulary_dir}" \
+         --train_dataset "${train_dataset}" \
+         --valid_dataset "${valid_dataset}" \
+         --serialization_dir "${file_dir}" \
+         --config_file "${config_file}" \
+         --pretrained_model "${pretrained_model}"
+
+ fi
+
+
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
+     $verbose && echo "stage 3: test model"
+     cd "${work_dir}" || exit 1
+     python3 step_4_evaluation_model.py \
+         --dataset "${dataset}" \
+         --vocabulary_dir "${vocabulary_dir}" \
+         --model_dir "${file_dir}/best" \
+         --output_file "${evaluation_file}"
+
+ fi
+
+
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
+     $verbose && echo "stage 4: export model"
+     cd "${work_dir}" || exit 1
+     python3 step_5_export_models.py \
+         --vocabulary_dir "${vocabulary_dir}" \
+         --model_dir "${file_dir}/best" \
+         --serialization_dir "${file_dir}"
+
+ fi
+
+
+ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
+     $verbose && echo "stage 5: collect files"
+     cd "${work_dir}" || exit 1
+
+     mkdir -p ${final_model_dir}
+
+     cp "${file_dir}/best"/* "${final_model_dir}"
+     cp -r "${file_dir}/vocabulary" "${final_model_dir}"
+
+     cp "${file_dir}/evaluation.xlsx" "${final_model_dir}/evaluation.xlsx"
+
+     cp "${file_dir}/trace_model.zip" "${final_model_dir}/trace_model.zip"
+     cp "${file_dir}/trace_quant_model.zip" "${final_model_dir}/trace_quant_model.zip"
+     cp "${file_dir}/script_model.zip" "${final_model_dir}/script_model.zip"
+     cp "${file_dir}/script_quant_model.zip" "${final_model_dir}/script_quant_model.zip"
+
+     cd "${final_model_dir}/.." || exit 1;
+
+     if [ -e "${final_model_name}.zip" ]; then
+         rm -rf "${final_model_name}_backup.zip"
+         mv "${final_model_name}.zip" "${final_model_name}_backup.zip"
+     fi
+
+     zip -r "${final_model_name}.zip" "${final_model_name}"
+     rm -rf "${final_model_name}"
+
+ fi
+
+
+ if [ ${stage} -le 6 ] && [ ${stop_stage} -ge 6 ]; then
+     $verbose && echo "stage 6: clear file_dir"
+     cd "${work_dir}" || exit 1
+
+     rm -rf "${file_dir}";
+
+ fi
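
The stage gates above follow the Kaldi-style convention: stage N runs exactly when stage <= N <= stop_stage, so any contiguous slice of the pipeline can be re-run in isolation. For example, a sketch reusing options documented in the heredoc at the top of the script:

    # redo only training and evaluation (stages 2-3), keeping the prepared data
    sh run.sh --stage 2 --stop_stage 3 --system_version centos --file_folder_name file_dir \
        --final_model_name sound-3-ch32 --label_plan 3 --config_file "yaml/conv2d-classifier-3-ch4.yaml"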
examples/vm_sound_classification/run_batch.sh ADDED
@@ -0,0 +1,268 @@
+ #!/usr/bin/env bash
+
+
+ # sound ch4
+
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 2 \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml"
+ #
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-3-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 3 \
+ #--config_file "yaml/conv2d-classifier-3-ch4.yaml"
+ #
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-4-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 4 \
+ #--config_file "yaml/conv2d-classifier-4-ch4.yaml"
+ #
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-8-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 8 \
+ #--config_file "yaml/conv2d-classifier-8-ch4.yaml"
+
+
+ # sound ch8
+
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-2-ch8 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 2 \
+ #--config_file "yaml/conv2d-classifier-2-ch8.yaml"
+ #
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-3-ch8 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 3 \
+ #--config_file "yaml/conv2d-classifier-3-ch8.yaml"
+ #
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-4-ch8 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 4 \
+ #--config_file "yaml/conv2d-classifier-4-ch8.yaml"
+ #
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-8-ch8 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 8 \
+ #--config_file "yaml/conv2d-classifier-8-ch8.yaml"
+
+
+ # sound ch16
+
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-2-ch16 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 2 \
+ #--config_file "yaml/conv2d-classifier-2-ch16.yaml"
+
+
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-3-ch16 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 3 \
+ #--config_file "yaml/conv2d-classifier-3-ch16.yaml"
+ #
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-4-ch16 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 4 \
+ #--config_file "yaml/conv2d-classifier-4-ch16.yaml"
+ #
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-8-ch16 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 8 \
+ #--config_file "yaml/conv2d-classifier-8-ch16.yaml"
+
+
+ # sound ch32
+
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 2 \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml"
+ #
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-3-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 3 \
+ #--config_file "yaml/conv2d-classifier-3-ch32.yaml"
+ #
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-4-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 4 \
+ #--config_file "yaml/conv2d-classifier-4-ch32.yaml"
+
+
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name sound-8-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ #--label_plan 8 \
+ #--config_file "yaml/conv2d-classifier-8-ch32.yaml"
+
+
+ # pretrained voicemail
+
+ sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-2-ch4 \
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ --label_plan 2-voicemail \
+ --config_file "yaml/conv2d-classifier-2-ch4.yaml"
+
+ sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-2-ch32 \
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav" \
+ --label_plan 2-voicemail \
+ --config_file "yaml/conv2d-classifier-2-ch32.yaml"
+
+
+ # voicemail ch4
+
+ sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-ph-2-ch4 \
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-PH/wav_finished/*/*.wav" \
+ --label_plan 2-voicemail \
+ --config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ --pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-sg-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-SG/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-us-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-US/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-es-mx-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/es-MX/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-es-pe-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/es-PE/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-id-id-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/id-ID/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ja-jp-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ja-JP/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ko-kr-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ko-KR/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ms-my-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ms-MY/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-pt-br-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/pt-BR/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-th-th-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/th-TH/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-zh-tw-2-ch4 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/zh-TW/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch4.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch4.zip"
+
+
+ # voicemail ch32
+
+ sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-ph-2-ch32 \
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-PH/wav_finished/*/*.wav" \
+ --label_plan 2-voicemail \
+ --config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ --pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
+
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-sg-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-SG/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-en-us-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/en-US/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-es-mx-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/es-MX/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-es-pe-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/es-PE/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-id-id-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/id-ID/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ja-jp-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ja-JP/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ko-kr-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ko-KR/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-ms-my-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/ms-MY/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-pt-br-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/pt-BR/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-th-th-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/th-TH/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
+ #
+ #sh run.sh --stage 0 --stop_stage 6 --system_version centos --file_folder_name file_dir --final_model_name voicemail-zh-tw-2-ch32 \
+ #--filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/zh-TW/wav_finished/*/*.wav" \
+ #--label_plan 2-voicemail \
+ #--config_file "yaml/conv2d-classifier-2-ch32.yaml" \
+ #--pretrained_model "/data/tianxing/PycharmProjects/vm_sound_classification/trained_models/voicemail-2-ch32.zip"
examples/vm_sound_classification/step_1_prepare_data.py ADDED
@@ -0,0 +1,194 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from glob import glob
+ import json
+ import os
+ from pathlib import Path
+ import random
+ import sys
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import pandas as pd
+ from scipy.io import wavfile
+ from tqdm import tqdm
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--file_dir", default="./", type=str)
+     parser.add_argument("--filename_patterns", type=str)
+
+     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+     parser.add_argument("--label_plan", default="4", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def get_dataset(args):
+     filename_patterns = args.filename_patterns
+     filename_patterns = filename_patterns.split(" ")
+     print(filename_patterns)
+
+     file_dir = Path(args.file_dir)
+     file_dir.mkdir(exist_ok=True)
+
+     if args.label_plan == "2-voicemail":
+         label_map = {
+             "bell": "voicemail",
+             "white_noise": "non_voicemail",
+             "low_white_noise": "non_voicemail",
+             "high_white_noise": "non_voicemail",
+             # "music": "non_voicemail",
+             "mute": "non_voicemail",
+             "noise": "non_voicemail",
+             "noise_mute": "non_voicemail",
+             "voice": "non_voicemail",
+             "voicemail": "voicemail",
+         }
+     elif args.label_plan == "2":
+         label_map = {
+             "bell": "non_voice",
+             "white_noise": "non_voice",
+             "low_white_noise": "non_voice",
+             "high_white_noise": "non_voice",
+             "music": "non_voice",
+             "mute": "non_voice",
+             "noise": "non_voice",
+             "noise_mute": "non_voice",
+             "voice": "voice",
+             "voicemail": "voice",
+         }
+     elif args.label_plan == "3":
+         label_map = {
+             "bell": "voicemail",
+             "white_noise": "mute",
+             "low_white_noise": "mute",
+             "high_white_noise": "mute",
+             # "music": "music",
+             "mute": "mute",
+             "noise": "voice_or_noise",
+             "noise_mute": "voice_or_noise",
+             "voice": "voice_or_noise",
+             "voicemail": "voicemail",
+         }
+     elif args.label_plan == "4":
+         label_map = {
+             "bell": "voicemail",
+             "white_noise": "mute",
+             "low_white_noise": "mute",
+             "high_white_noise": "mute",
+             # "music": "music",
+             "mute": "mute",
+             "noise": "noise",
+             "noise_mute": "noise",
+             "voice": "voice",
+             "voicemail": "voicemail",
+         }
+     elif args.label_plan == "8":
+         label_map = {
+             "bell": "bell",
+             "white_noise": "white_noise",
+             "low_white_noise": "white_noise",
+             "high_white_noise": "white_noise",
+             "music": "music",
+             "mute": "mute",
+             "noise": "noise",
+             "noise_mute": "noise_mute",
+             "voice": "voice",
+             "voicemail": "voicemail",
+         }
+     else:
+         raise AssertionError
+
+     result = list()
+     for filename_pattern in filename_patterns:
+         filename_list = glob(filename_pattern)
+         for filename in tqdm(filename_list):
+             filename = Path(filename)
+             sample_rate, signal = wavfile.read(filename.as_posix())
+             if len(signal) < sample_rate * 2:
+                 continue
+
+             folder = filename.parts[-2]
+             country = filename.parts[-4]
+
+             if folder not in label_map.keys():
+                 continue
+
+             labels = label_map[folder]
+
+             random1 = random.random()
+             random2 = random.random()
+
+             result.append({
+                 "filename": filename,
+                 "folder": folder,
+                 "category": country,
+                 "labels": labels,
+                 "random1": random1,
+                 "random2": random2,
+                 "flag": "TRAIN" if random2 < 0.8 else "TEST",
+             })
+
+     df = pd.DataFrame(result)
+     pivot_table = pd.pivot_table(df, index=["labels"], values=["filename"], aggfunc="count")
+     print(pivot_table)
+
+     df = df.sort_values(by=["random1"], ascending=False)
+     df.to_excel(
+         file_dir / "dataset.xlsx",
+         index=False,
+         # encoding="utf_8_sig"
+     )
+
+     return
+
+
+ def split_dataset(args):
+     """Split into train and test sets."""
+     file_dir = Path(args.file_dir)
+     file_dir.mkdir(exist_ok=True)
+
+     df = pd.read_excel(file_dir / "dataset.xlsx")
+
+     train = list()
+     test = list()
+
+     for i, row in df.iterrows():
+         flag = row["flag"]
+         if flag == "TRAIN":
+             train.append(row)
+         else:
+             test.append(row)
+
+     train = pd.DataFrame(train)
+     train.to_excel(
+         args.train_dataset,
+         index=False,
+         # encoding="utf_8_sig"
+     )
+     test = pd.DataFrame(test)
+     test.to_excel(
+         args.valid_dataset,
+         index=False,
+         # encoding="utf_8_sig"
+     )
+
+     return
+
+
+ def main():
+     args = get_args()
+     get_dataset(args)
+     split_dataset(args)
+     return
+
+
+ if __name__ == "__main__":
+     main()
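
get_dataset() reads the label from filename.parts[-2] and the country from parts[-4], so it presupposes the directory layout implied by the glob patterns in run.sh and run_batch.sh — a sketch, using the dataset root those scripts point at:

    # <root>/voicemail/<locale>/wav_finished/<label>/<file>.wav
    ls /data/tianxing/PycharmProjects/datasets/voicemail/en-US/wav_finished/voicemail/*.wav
    # parts[-4] = "en-US" (country), parts[-2] = "voicemail" (label folder)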
examples/vm_sound_classification/step_2_make_vocabulary.py ADDED
@@ -0,0 +1,51 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ import os
+ from pathlib import Path
+ import sys
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import pandas as pd
+
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
+
+     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+     args = parser.parse_args()
+     return args
+
+
+ def main():
+     args = get_args()
+
+     train_dataset = pd.read_excel(args.train_dataset)
+     valid_dataset = pd.read_excel(args.valid_dataset)
+
+     vocabulary = Vocabulary()
+
+     # train
+     for i, row in train_dataset.iterrows():
+         label = row["labels"]
+         vocabulary.add_token_to_namespace(label, namespace="labels")
+
+     # valid
+     for i, row in valid_dataset.iterrows():
+         label = row["labels"]
+         vocabulary.add_token_to_namespace(label, namespace="labels")
+
+     vocabulary.save_to_files(args.vocabulary_dir)
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
examples/vm_sound_classification/step_3_train_model.py ADDED
@@ -0,0 +1,367 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from collections import defaultdict
5
+ import json
6
+ import logging
7
+ from logging.handlers import TimedRotatingFileHandler
8
+ import os
9
+ import platform
10
+ from pathlib import Path
11
+ import random
12
+ import sys
13
+ import shutil
14
+ import tempfile
15
+ from typing import List
16
+ import zipfile
17
+
18
+ pwd = os.path.abspath(os.path.dirname(__file__))
19
+ sys.path.append(os.path.join(pwd, "../../"))
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch.utils.data.dataloader import DataLoader
24
+ from tqdm import tqdm
25
+
26
+ from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
27
+ from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
28
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
29
+ from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
30
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
31
+ from toolbox.torchaudio.models.cnn_audio_classifier.configuration_cnn_audio_classifier import CnnAudioClassifierConfig
32
+
33
+
34
+ def get_args():
35
+ parser = argparse.ArgumentParser()
36
+ parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
37
+
38
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
39
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
40
+
41
+ parser.add_argument("--max_epochs", default=100, type=int)
42
+
43
+ parser.add_argument("--batch_size", default=64, type=int)
44
+ parser.add_argument("--learning_rate", default=1e-3, type=float)
45
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
46
+ parser.add_argument("--patience", default=5, type=int)
47
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
48
+ parser.add_argument("--seed", default=0, type=int)
49
+
50
+ parser.add_argument("--config_file", default="conv2d_classifier.yaml", type=str)
51
+ parser.add_argument(
52
+ "--pretrained_model",
53
+ # default=(project_path / "trained_models/voicemail-en-sg-2-ch4.zip").as_posix(),
54
+ default="null",
55
+ type=str
56
+ )
57
+
58
+ args = parser.parse_args()
59
+ return args
60
+
61
+
62
+ def logging_config(file_dir: str):
63
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
64
+
65
+ logging.basicConfig(format=fmt,
66
+ datefmt="%m/%d/%Y %H:%M:%S",
67
+ level=logging.DEBUG)
68
+ file_handler = TimedRotatingFileHandler(
69
+ filename=os.path.join(file_dir, "main.log"),
70
+ encoding="utf-8",
71
+ when="D",
72
+ interval=1,
73
+ backupCount=7
74
+ )
75
+ file_handler.setLevel(logging.INFO)
76
+ file_handler.setFormatter(logging.Formatter(fmt))
77
+ logger = logging.getLogger(__name__)
78
+ logger.addHandler(file_handler)
79
+
80
+ return logger
81
+
82
+
83
+ class CollateFunction(object):
84
+ def __init__(self):
85
+ pass
86
+
87
+ def __call__(self, batch: List[dict]):
88
+ array_list = list()
89
+ label_list = list()
90
+ for sample in batch:
91
+ array = sample["waveform"]
92
+ label = sample["label"]
93
+
94
+ l = len(array)
95
+ if l < 16000:
96
+ delta = int(16000 - l)
97
+ array = np.concatenate([array, np.zeros(shape=(delta,), dtype=np.float32)], axis=-1)
98
+ if l > 16000:
99
+ array = array[:16000]
100
+
101
+ array_list.append(array)
102
+ label_list.append(label)
103
+
104
+ array_list = torch.stack(array_list)
105
+ label_list = torch.stack(label_list)
106
+ return array_list, label_list
107
+
108
+
109
+ collate_fn = CollateFunction()
110
+
111
+
112
+ def main():
113
+ args = get_args()
114
+
115
+ serialization_dir = Path(args.serialization_dir)
116
+ serialization_dir.mkdir(parents=True, exist_ok=True)
117
+
118
+ logger = logging_config(serialization_dir)
119
+
120
+ random.seed(args.seed)
121
+ np.random.seed(args.seed)
122
+ torch.manual_seed(args.seed)
123
+ logger.info("set seed: {}".format(args.seed))
124
+
125
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
126
+ n_gpu = torch.cuda.device_count()
127
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
128
+
129
+ vocabulary = Vocabulary.from_files(args.vocabulary_dir)
130
+
131
+ # datasets
132
+ logger.info("prepare datasets")
133
+ train_dataset = WaveClassifierExcelDataset(
134
+ vocab=vocabulary,
135
+ excel_file=args.train_dataset,
136
+ category=None,
137
+ category_field="category",
138
+ label_field="labels",
139
+ expected_sample_rate=8000,
140
+ max_wave_value=32768.0,
141
+ )
142
+ valid_dataset = WaveClassifierExcelDataset(
143
+ vocab=vocabulary,
144
+ excel_file=args.valid_dataset,
145
+ category=None,
146
+ category_field="category",
147
+ label_field="labels",
148
+ expected_sample_rate=8000,
149
+ max_wave_value=32768.0,
150
+ )
151
+ train_data_loader = DataLoader(
152
+ dataset=train_dataset,
153
+ batch_size=args.batch_size,
154
+ shuffle=True,
155
+ # Multiple dataloader worker subprocesses are supported on Linux, but not on Windows.
156
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
157
+ collate_fn=collate_fn,
158
+ pin_memory=False,
159
+ # prefetch_factor=64,
160
+ )
161
+ valid_data_loader = DataLoader(
162
+ dataset=valid_dataset,
163
+ batch_size=args.batch_size,
164
+ shuffle=True,
165
+ # Multiple dataloader worker subprocesses are supported on Linux, but not on Windows.
166
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
167
+ collate_fn=collate_fn,
168
+ pin_memory=False,
169
+ # prefetch_factor=64,
170
+ )
171
+
172
+ # models
173
+ logger.info(f"prepare models. config_file: {args.config_file}")
174
+ config = CnnAudioClassifierConfig.from_pretrained(
175
+ pretrained_model_name_or_path=args.config_file,
176
+ # num_labels=vocabulary.get_vocab_size(namespace="labels")
177
+ )
178
+ if config.cls_head_param["num_labels"] != vocabulary.get_vocab_size(namespace="labels"):
179
+ raise AssertionError("expected num labels: {} instead of {}.".format(
180
+ vocabulary.get_vocab_size(namespace="labels"),
181
+ config.cls_head_param["num_labels"],
182
+ ))
183
+ model = WaveClassifierPretrainedModel(
184
+ config=config,
185
+ )
186
+
187
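+ # The pretrained checkpoint is a zip whose top-level folder matches the file stem and contains model.pt.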
+ if args.pretrained_model is not None and os.path.exists(args.pretrained_model):
188
+ logger.info(f"load pretrained model state dict from: {args.pretrained_model}")
189
+ pretrained_model = Path(args.pretrained_model)
190
+ with zipfile.ZipFile(pretrained_model.as_posix(), "r") as f_zip:
191
+ out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
192
+ # print(out_root.as_posix())
193
+ if out_root.exists():
194
+ shutil.rmtree(out_root.as_posix())
195
+ out_root.mkdir(parents=True, exist_ok=True)
196
+ f_zip.extractall(path=out_root)
197
+
198
+ tgt_path = out_root / pretrained_model.stem
199
+ model_pt_file = tgt_path / "model.pt"
200
+ with open(model_pt_file, "rb") as f:
201
+ state_dict = torch.load(f, map_location="cpu")
202
+ model.load_state_dict(state_dict=state_dict)
203
+
204
+ model.to(device)
205
+ model.train()
206
+
207
+ # optimizer
208
+ logger.info("prepare optimizer, lr_scheduler, loss_fn, categorical_accuracy")
209
+ param_optimizer = model.parameters()
210
+ optimizer = torch.optim.Adam(
211
+ param_optimizer,
212
+ lr=args.learning_rate,
213
+ )
214
+ # lr_scheduler = torch.optim.lr_scheduler.StepLR(
215
+ # optimizer,
216
+ # step_size=2000
217
+ # )
218
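+ # lr_scheduler.step() is called once per batch below, so these milestones count optimizer steps, not epochs.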
+ lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(
219
+ optimizer,
220
+ milestones=[10000, 20000, 30000, 40000, 50000], gamma=0.5
221
+ )
222
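+ # Focal loss down-weights well-classified examples, a common choice when the label distribution is imbalanced.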
+ focal_loss = FocalLoss(
223
+ num_classes=vocabulary.get_vocab_size(namespace="labels"),
224
+ reduction="mean",
225
+ )
226
+ categorical_accuracy = CategoricalAccuracy()
227
+
228
+ # training loop
229
+ logger.info("training")
230
+
231
+ training_loss = 10000000000
232
+ training_accuracy = 0.
233
+ evaluation_loss = 10000000000
234
+ evaluation_accuracy = 0.
235
+
236
+ model_list = list()
237
+ best_idx_epoch = None
238
+ best_accuracy = None
239
+ patience_count = 0
240
+
241
+ for idx_epoch in range(args.max_epochs):
242
+ categorical_accuracy.reset()
243
+ total_loss = 0.
244
+ total_examples = 0.
245
+ progress_bar = tqdm(
246
+ total=len(train_data_loader),
247
+ desc="Training; epoch: {}".format(idx_epoch),
248
+ )
249
+ for batch in train_data_loader:
250
+ input_ids, label_ids = batch
251
+ input_ids = input_ids.to(device)
252
+ label_ids: torch.LongTensor = label_ids.to(device).long()
253
+
254
+ logits = model.forward(input_ids)
255
+ loss = focal_loss.forward(logits, label_ids.view(-1))
256
+ categorical_accuracy(logits, label_ids)
257
+
258
+ total_loss += loss.item()
259
+ total_examples += input_ids.size(0)
260
+
261
+ optimizer.zero_grad()
262
+ loss.backward()
263
+ optimizer.step()
264
+ lr_scheduler.step()
265
+
266
+ training_loss = total_loss / total_examples
267
+ training_loss = round(training_loss, 4)
268
+ training_accuracy = categorical_accuracy.get_metric()["accuracy"]
269
+ training_accuracy = round(training_accuracy, 4)
270
+
271
+ progress_bar.update(1)
272
+ progress_bar.set_postfix({
273
+ "training_loss": training_loss,
274
+ "training_accuracy": training_accuracy,
275
+ })
276
+
277
+ categorical_accuracy.reset()
278
+ total_loss = 0.
279
+ total_examples = 0.
280
+ progress_bar = tqdm(
281
+ total=len(valid_data_loader),
282
+ desc="Evaluation; epoch: {}".format(idx_epoch),
283
+ )
284
+ for batch in valid_data_loader:
285
+ input_ids, label_ids = batch
286
+ input_ids = input_ids.to(device)
287
+ label_ids: torch.LongTensor = label_ids.to(device).long()
288
+
289
+ with torch.no_grad():
290
+ logits = model.forward(input_ids)
291
+ loss = focal_loss.forward(logits, label_ids.view(-1))
292
+ categorical_accuracy(logits, label_ids)
293
+
294
+ total_loss += loss.item()
295
+ total_examples += input_ids.size(0)
296
+
297
+ evaluation_loss = total_loss / total_examples
298
+ evaluation_loss = round(evaluation_loss, 4)
299
+ evaluation_accuracy = categorical_accuracy.get_metric()["accuracy"]
300
+ evaluation_accuracy = round(evaluation_accuracy, 4)
301
+
302
+ progress_bar.update(1)
303
+ progress_bar.set_postfix({
304
+ "evaluation_loss": evaluation_loss,
305
+ "evaluation_accuracy": evaluation_accuracy,
306
+ })
307
+
308
+ # save path
309
+ epoch_dir = serialization_dir / "epoch-{}".format(idx_epoch)
310
+ epoch_dir.mkdir(parents=True, exist_ok=False)
311
+
312
+ # save models
313
+ model.save_pretrained(epoch_dir.as_posix())
314
+
315
+ model_list.append(epoch_dir)
316
+ if len(model_list) > args.num_serialized_models_to_keep:
317
+ model_to_delete: Path = model_list.pop(0)
318
+ shutil.rmtree(model_to_delete.as_posix())
319
+
320
+ # save metric
321
+ if best_accuracy is None:
322
+ best_idx_epoch = idx_epoch
323
+ best_accuracy = evaluation_accuracy
324
+ elif evaluation_accuracy > best_accuracy:
325
+ best_idx_epoch = idx_epoch
326
+ best_accuracy = evaluation_accuracy
327
+ else:
328
+ pass
329
+
330
+ metrics = {
331
+ "idx_epoch": idx_epoch,
332
+ "best_idx_epoch": best_idx_epoch,
333
+ "best_accuracy": best_accuracy,
334
+ "training_loss": training_loss,
335
+ "training_accuracy": training_accuracy,
336
+ "evaluation_loss": evaluation_loss,
337
+ "evaluation_accuracy": evaluation_accuracy,
338
+ "learning_rate": optimizer.param_groups[0]['lr'],
339
+ }
340
+ metrics_filename = epoch_dir / "metrics_epoch.json"
341
+ with open(metrics_filename, "w", encoding="utf-8") as f:
342
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
343
+
344
+ # save best
345
+ best_dir = serialization_dir / "best"
346
+ if best_idx_epoch == idx_epoch:
347
+ if best_dir.exists():
348
+ shutil.rmtree(best_dir)
349
+ shutil.copytree(epoch_dir, best_dir)
350
+
351
+ # early stop
352
+ early_stop_flag = False
353
+ if best_idx_epoch == idx_epoch:
354
+ patience_count = 0
355
+ else:
356
+ patience_count += 1
357
+ if patience_count >= args.patience:
358
+ early_stop_flag = True
359
+
360
+ # early stop
361
+ if early_stop_flag:
362
+ break
363
+ return
364
+
365
+
366
+ if __name__ == "__main__":
367
+ main()
examples/vm_sound_classification/step_4_evaluation_model.py ADDED
@@ -0,0 +1,128 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from collections import defaultdict
5
+ import json
6
+ import logging
7
+ from logging.handlers import TimedRotatingFileHandler
8
+ import os
9
+ import platform
10
+ from pathlib import Path
11
+ import sys
12
+ import shutil
13
+ from typing import List
14
+
15
+ pwd = os.path.abspath(os.path.dirname(__file__))
16
+ sys.path.append(os.path.join(pwd, "../../"))
17
+
18
+ import pandas as pd
19
+ from scipy.io import wavfile
20
+ import torch
21
+ from tqdm import tqdm
22
+
23
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
24
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
25
+
26
+
27
+ def get_args():
28
+ parser = argparse.ArgumentParser()
29
+ parser.add_argument("--dataset", default="dataset.xlsx", type=str)
30
+
31
+ parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
32
+ parser.add_argument("--model_dir", default="best", type=str)
33
+
34
+ parser.add_argument("--output_file", default="evaluation.xlsx", type=str)
35
+
36
+ args = parser.parse_args()
37
+ return args
38
+
39
+
40
+ def logging_config():
41
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
42
+
43
+ logging.basicConfig(format=fmt,
44
+ datefmt="%m/%d/%Y %H:%M:%S",
45
+ level=logging.DEBUG)
46
+ stream_handler = logging.StreamHandler()
47
+ stream_handler.setLevel(logging.INFO)
48
+ stream_handler.setFormatter(logging.Formatter(fmt))
49
+
50
+ logger = logging.getLogger(__name__)
+ logger.addHandler(stream_handler)
51
+
52
+ return logger
53
+
54
+
55
+ def main():
56
+ args = get_args()
57
+
58
+ logger = logging_config()
59
+
60
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
61
+ n_gpu = torch.cuda.device_count()
62
+ logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
63
+
64
+ logger.info("prepare vocabulary, model")
65
+ vocabulary = Vocabulary.from_files(args.vocabulary_dir)
66
+
67
+ model = WaveClassifierPretrainedModel.from_pretrained(
68
+ pretrained_model_name_or_path=args.model_dir,
69
+ )
70
+ model.to(device)
71
+ model.eval()
72
+
73
+ logger.info("read excel")
74
+ df = pd.read_excel(args.dataset)
75
+ result = list()
76
+
77
+ total_correct = 0
78
+ total_examples = 0
79
+
80
+ progress_bar = tqdm(total=len(df), desc="Evaluation")
81
+ for i, row in df.iterrows():
82
+ filename = row["filename"]
83
+ ground_true = row["labels"]
84
+
85
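+ # wavfile.read returns int16 PCM here; dividing by 2**15 scales it to [-1.0, 1.0).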
+ sample_rate, waveform = wavfile.read(filename)
86
+ waveform = waveform / (1 << 15)
87
+ waveform = torch.tensor(waveform, dtype=torch.float32)
88
+ waveform = torch.unsqueeze(waveform, dim=0)
89
+ waveform = waveform.to(device)
90
+
91
+ with torch.no_grad():
92
+ logits = model.forward(waveform)
93
+ probs = torch.nn.functional.softmax(logits, dim=-1)
94
+ label_idx = torch.argmax(probs, dim=-1)
95
+
96
+ label_idx = label_idx.cpu()
97
+ probs = probs.cpu()
98
+
99
+ label_idx = label_idx.numpy()[0]
100
+ label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
101
+ prob = probs[0][label_idx].numpy()
102
+
103
+ correct = 1 if label_str == ground_true else 0
104
+ row_ = dict(row)
105
+ row_["predict"] = label_str
106
+ row_["prob"] = prob
107
+ row_["correct"] = correct
108
+ result.append(row_)
109
+
110
+ total_examples += 1
111
+ total_correct += correct
112
+ accuracy = total_correct / total_examples
113
+
114
+ progress_bar.update(1)
115
+ progress_bar.set_postfix({
116
+ "accuracy": accuracy,
117
+ })
118
+
119
+ result = pd.DataFrame(result)
120
+ result.to_excel(
121
+ args.output_file,
122
+ index=False
123
+ )
124
+ return
125
+
126
+
127
+ if __name__ == '__main__':
128
+ main()
examples/vm_sound_classification/step_5_export_models.py ADDED
@@ -0,0 +1,106 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from collections import defaultdict
5
+ import json
6
+ import logging
7
+ from logging.handlers import TimedRotatingFileHandler
8
+ import os
9
+ import platform
10
+ from pathlib import Path
11
+ import sys
12
+ import shutil
13
+ from typing import List
14
+
15
+ pwd = os.path.abspath(os.path.dirname(__file__))
16
+ sys.path.append(os.path.join(pwd, "../../"))
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
22
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
23
+
24
+
25
+ def get_args():
26
+ parser = argparse.ArgumentParser()
27
+ parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
28
+ parser.add_argument("--model_dir", default="best", type=str)
29
+
30
+ parser.add_argument("--serialization_dir", default="serialization_dir", type=str)
31
+
32
+ args = parser.parse_args()
33
+ return args
34
+
35
+
36
+ def logging_config():
37
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
38
+
39
+ logging.basicConfig(format=fmt,
40
+ datefmt="%m/%d/%Y %H:%M:%S",
41
+ level=logging.DEBUG)
42
+ stream_handler = logging.StreamHandler()
43
+ stream_handler.setLevel(logging.INFO)
44
+ stream_handler.setFormatter(logging.Formatter(fmt))
45
+
46
+ logger = logging.getLogger(__name__)
+ logger.addHandler(stream_handler)
47
+
48
+ return logger
49
+
50
+
51
+ def main():
52
+ args = get_args()
53
+
54
+ serialization_dir = Path(args.serialization_dir)
55
+
56
+ logger = logging_config()
57
+
58
+ logger.info("export models on CPU")
59
+ device = torch.device("cpu")
60
+
61
+ logger.info("prepare vocabulary, model")
62
+ vocabulary = Vocabulary.from_files(args.vocabulary_dir)
63
+
64
+ model = WaveClassifierPretrainedModel.from_pretrained(
65
+ pretrained_model_name_or_path=args.model_dir,
66
+ num_labels=vocabulary.get_vocab_size(namespace="labels")
67
+ )
68
+ model.to(device)
69
+ model.eval()
70
+
71
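+ # Synthetic 16000-sample noise clip used only as a dummy input for tracing; its values are irrelevant.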
+ waveform = 0 + 25 * np.random.randn(16000,)
72
+ waveform = np.array(waveform, dtype=np.int16)
73
+ waveform = waveform / (1 << 15)
74
+ waveform = torch.tensor(waveform, dtype=torch.float32)
75
+ waveform = torch.unsqueeze(waveform, dim=0)
76
+ waveform = waveform.to(device)
77
+
78
+ logger.info("export jit models")
79
+ example_inputs = (waveform,)
80
+
81
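+ # torch.jit.trace records one concrete execution path; torch.jit.script compiles the module code, preserving control flow.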
+ # trace model
82
+ trace_model = torch.jit.trace(func=model, example_inputs=example_inputs, strict=False)
83
+ trace_model.save(serialization_dir / "trace_model.zip")
84
+
85
+ # dynamically quantized traced model (dynamic quantization runs on CPU only)
86
+ quantized_model = torch.quantization.quantize_dynamic(
87
+ model, {torch.nn.Linear}, dtype=torch.qint8
88
+ )
89
+ trace_quant_model = torch.jit.trace(func=quantized_model, example_inputs=example_inputs, strict=False)
90
+ trace_quant_model.save(serialization_dir / "trace_quant_model.zip")
91
+
92
+ # script model
93
+ script_model = torch.jit.script(obj=model)
94
+ script_model.save(serialization_dir / "script_model.zip")
95
+
96
+ # dynamically quantized scripted model (dynamic quantization runs on CPU only)
97
+ quantized_model = torch.quantization.quantize_dynamic(
98
+ model, {torch.nn.Linear}, dtype=torch.qint8
99
+ )
100
+ script_quant_model = torch.jit.script(quantized_model)
101
+ script_quant_model.save(serialization_dir / "script_quant_model.zip")
102
+ return
103
+
104
+
105
+ if __name__ == '__main__':
106
+ main()
examples/vm_sound_classification/step_6_infer.py ADDED
@@ -0,0 +1,91 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import shutil
7
+ import sys
8
+ import tempfile
9
+ import zipfile
10
+
11
+ pwd = os.path.abspath(os.path.dirname(__file__))
12
+ sys.path.append(os.path.join(pwd, "../../"))
13
+
14
+ from scipy.io import wavfile
15
+ import torch
16
+
17
+ from project_settings import project_path
18
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
19
+
20
+
21
+ def get_args():
22
+ parser = argparse.ArgumentParser()
23
+ parser.add_argument(
24
+ "--model_file",
25
+ default=(project_path / "trained_models/vm_sound_classification3.zip").as_posix(),
26
+ type=str
27
+ )
28
+ parser.add_argument(
29
+ "--wav_file",
30
+ default=r"C:\Users\tianx\Desktop\a073d03d-d280-46df-9b2d-d904965f4500_zh-CN_h3f25ivhb0c0_1719478037746.wav",
31
+ type=str
32
+ )
33
+
34
+ parser.add_argument("--device", default="cpu", type=str)
35
+
36
+ args = parser.parse_args()
37
+ return args
38
+
39
+
40
+ def main():
41
+ args = get_args()
42
+
43
+ model_file = Path(args.model_file)
44
+
45
+ device = torch.device(args.device)
46
+
47
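+ # The archive is expected to unpack to <stem>/trace_model.zip plus <stem>/vocabulary/.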
+ with zipfile.ZipFile(model_file, "r") as f_zip:
48
+ out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
49
+ print(out_root.as_posix())
50
+ if out_root.exists():
51
+ shutil.rmtree(out_root.as_posix())
52
+ out_root.mkdir(parents=True, exist_ok=True)
53
+ f_zip.extractall(path=out_root)
54
+
55
+ tgt_path = out_root / model_file.stem
56
+ jit_model_file = tgt_path / "trace_model.zip"
57
+ vocab_path = tgt_path / "vocabulary"
58
+
59
+ with open(jit_model_file.as_posix(), "rb") as f:
60
+ model = torch.jit.load(f)
61
+ model.to(device)
62
+ model.eval()
63
+ vocabulary = Vocabulary.from_files(vocab_path.as_posix())
64
+
65
+ # infer
66
+ sample_rate, waveform = wavfile.read(args.wav_file)
67
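+ # Keep only the first 16000 samples (2 s at the expected 8 kHz) and scale int16 PCM to [-1, 1).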
+ waveform = waveform[:16000]
68
+ waveform = waveform / (1 << 15)
69
+ waveform = torch.tensor(waveform, dtype=torch.float32)
70
+ waveform = torch.unsqueeze(waveform, dim=0)
71
+ waveform = waveform.to(device)
72
+
73
+ with torch.no_grad():
74
+ logits = model.forward(waveform)
75
+ probs = torch.nn.functional.softmax(logits, dim=-1)
76
+ label_idx = torch.argmax(probs, dim=-1)
77
+
78
+ label_idx = label_idx.cpu()
79
+ probs = probs.cpu()
80
+
81
+ label_idx = label_idx.numpy()[0]
82
+ prob = probs.numpy()[0][label_idx]
83
+
84
+ label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
85
+ print(label_str)
86
+ print(prob)
87
+ return
88
+
89
+
90
+ if __name__ == '__main__':
91
+ main()
examples/vm_sound_classification/step_7_test_model.py ADDED
@@ -0,0 +1,93 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import shutil
7
+ import sys
8
+ import tempfile
9
+ import zipfile
10
+
11
+ pwd = os.path.abspath(os.path.dirname(__file__))
12
+ sys.path.append(os.path.join(pwd, "../../"))
13
+
14
+ from scipy.io import wavfile
15
+ import torch
16
+
17
+ from project_settings import project_path
18
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
19
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveClassifierPretrainedModel
20
+
21
+
22
+ def get_args():
23
+ parser = argparse.ArgumentParser()
24
+ parser.add_argument(
25
+ "--model_file",
26
+ default=(project_path / "trained_models/vm_sound_classification3.zip").as_posix(),
27
+ type=str
28
+ )
29
+ parser.add_argument(
30
+ "--wav_file",
31
+ default=r"C:\Users\tianx\Desktop\4b284733-0be3-4a48-abbb-615b32ac44b7_6ndddc2szlh0.wav",
32
+ type=str
33
+ )
34
+
35
+ parser.add_argument("--device", default="cpu", type=str)
36
+
37
+ args = parser.parse_args()
38
+ return args
39
+
40
+
41
+ def main():
42
+ args = get_args()
43
+
44
+ model_file = Path(args.model_file)
45
+
46
+ device = torch.device(args.device)
47
+
48
+ with zipfile.ZipFile(model_file, "r") as f_zip:
49
+ out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
50
+ print(out_root)
51
+ if out_root.exists():
52
+ shutil.rmtree(out_root.as_posix())
53
+ out_root.mkdir(parents=True, exist_ok=True)
54
+ f_zip.extractall(path=out_root)
55
+
56
+ tgt_path = out_root / model_file.stem
57
+ vocab_path = tgt_path / "vocabulary"
58
+
59
+ vocabulary = Vocabulary.from_files(vocab_path.as_posix())
60
+
61
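+ # Unlike step_6_infer, this loads the eager (non-TorchScript) model straight from the extracted folder.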
+ model = WaveClassifierPretrainedModel.from_pretrained(
62
+ pretrained_model_name_or_path=tgt_path.as_posix(),
63
+ )
64
+ model.to(device)
65
+ model.eval()
66
+
67
+ # infer
68
+ sample_rate, waveform = wavfile.read(args.wav_file)
69
+ waveform = waveform[:16000]
70
+ waveform = waveform / (1 << 15)
71
+ waveform = torch.tensor(waveform, dtype=torch.float32)
72
+ waveform = torch.unsqueeze(waveform, dim=0)
73
+ waveform = waveform.to(device)
74
+ print(waveform.shape)
75
+ with torch.no_grad():
76
+ logits = model.forward(waveform)
77
+ probs = torch.nn.functional.softmax(logits, dim=-1)
78
+ label_idx = torch.argmax(probs, dim=-1)
79
+
80
+ label_idx = label_idx.cpu()
81
+ probs = probs.cpu()
82
+
83
+ label_idx = label_idx.numpy()[0]
84
+ prob = probs.numpy()[0][label_idx]
85
+
86
+ label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
87
+ print(label_str)
88
+ print(prob)
89
+ return
90
+
91
+
92
+ if __name__ == '__main__':
93
+ main()
examples/vm_sound_classification/stop.sh ADDED
@@ -0,0 +1,3 @@
1
+ #!/usr/bin/env bash
2
+
3
+ kill -9 `ps -aef | grep 'vm_sound_classification/bin/python3' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
examples/vm_sound_classification/yaml/conv2d-classifier-2-ch16.yaml ADDED
@@ -0,0 +1,45 @@
1
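+ # Config naming: conv2d-classifier-<num_labels>-ch<out_channels>; cls_head input_dim is 27 x out_channels.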
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 16
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 16
23
+ out_channels: 16
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 16
30
+ out_channels: 16
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 432
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 2
examples/vm_sound_classification/yaml/conv2d-classifier-2-ch32.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 32
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 32
23
+ out_channels: 32
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 32
30
+ out_channels: 32
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 864
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 2
examples/vm_sound_classification/yaml/conv2d-classifier-2-ch4.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 4
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 4
23
+ out_channels: 4
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 4
30
+ out_channels: 4
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 108
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 2
examples/vm_sound_classification/yaml/conv2d-classifier-2-ch8.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 8
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 8
23
+ out_channels: 8
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 8
30
+ out_channels: 8
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 216
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 2
examples/vm_sound_classification/yaml/conv2d-classifier-3-ch16.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 16
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 16
23
+ out_channels: 16
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 16
30
+ out_channels: 16
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 432
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 3
examples/vm_sound_classification/yaml/conv2d-classifier-3-ch32.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 32
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 32
23
+ out_channels: 32
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 32
30
+ out_channels: 32
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 864
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 3
examples/vm_sound_classification/yaml/conv2d-classifier-3-ch4.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 4
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 4
23
+ out_channels: 4
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 4
30
+ out_channels: 4
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 108
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 3
examples/vm_sound_classification/yaml/conv2d-classifier-3-ch8.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 8
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 8
23
+ out_channels: 8
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 8
30
+ out_channels: 8
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 216
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 3
examples/vm_sound_classification/yaml/conv2d-classifier-4-ch16.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 16
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 16
23
+ out_channels: 16
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 16
30
+ out_channels: 16
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 432
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 4
examples/vm_sound_classification/yaml/conv2d-classifier-4-ch32.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 32
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 32
23
+ out_channels: 32
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 32
30
+ out_channels: 32
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 864
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 4
examples/vm_sound_classification/yaml/conv2d-classifier-4-ch4.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 4
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 4
23
+ out_channels: 4
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 4
30
+ out_channels: 4
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 108
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 4
examples/vm_sound_classification/yaml/conv2d-classifier-4-ch8.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 8
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 8
23
+ out_channels: 8
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 8
30
+ out_channels: 8
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 216
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 4
examples/vm_sound_classification/yaml/conv2d-classifier-8-ch16.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 16
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 16
23
+ out_channels: 16
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 16
30
+ out_channels: 16
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 432
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 8
examples/vm_sound_classification/yaml/conv2d-classifier-8-ch32.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 32
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 32
23
+ out_channels: 32
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 32
30
+ out_channels: 32
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 864
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 8
examples/vm_sound_classification/yaml/conv2d-classifier-8-ch4.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 4
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 4
23
+ out_channels: 4
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 4
30
+ out_channels: 4
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 108
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 8
examples/vm_sound_classification/yaml/conv2d-classifier-8-ch8.yaml ADDED
@@ -0,0 +1,45 @@
1
+ model_name: "cnn_audio_classifier"
2
+
3
+ mel_spectrogram_param:
4
+ sample_rate: 8000
5
+ n_fft: 512
6
+ win_length: 200
7
+ hop_length: 80
8
+ f_min: 10
9
+ f_max: 3800
10
+ window_fn: hamming
11
+ n_mels: 80
12
+
13
+ conv2d_block_param_list:
14
+ - batch_norm: true
15
+ in_channels: 1
16
+ out_channels: 8
17
+ kernel_size: 3
18
+ stride: 1
19
+ dilation: 3
20
+ activation: relu
21
+ dropout: 0.1
22
+ - in_channels: 8
23
+ out_channels: 8
24
+ kernel_size: 5
25
+ stride: 2
26
+ dilation: 3
27
+ activation: relu
28
+ dropout: 0.1
29
+ - in_channels: 8
30
+ out_channels: 8
31
+ kernel_size: 3
32
+ stride: 1
33
+ dilation: 2
34
+ activation: relu
35
+ dropout: 0.1
36
+
37
+ cls_head_param:
38
+ input_dim: 216
39
+ num_layers: 2
40
+ hidden_dims:
41
+ - 128
42
+ - 32
43
+ activations: relu
44
+ dropout: 0.1
45
+ num_labels: 8
examples/vm_sound_classification8/requirements.txt ADDED
@@ -0,0 +1,9 @@
1
+ torch==1.10.1
2
+ torchaudio==0.10.1
3
+ fsspec==2022.1.0
4
+ librosa==0.9.2
5
+ pandas==1.1.5
6
+ openpyxl==3.0.9
7
+ xlrd==1.2.0
8
+ tqdm==4.64.1
9
+ overrides==1.9.0
examples/vm_sound_classification8/run.sh ADDED
@@ -0,0 +1,157 @@
1
+ #!/usr/bin/env bash
2
+
3
+ : <<'END'
4
+
5
+ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
6
+ --filename_patterns "E:/programmer/asr_datasets/voicemail/wav_finished/en-US/wav_finished/*/*.wav \
7
+ E:/programmer/asr_datasets/voicemail/wav_finished/id-ID/wav_finished/*/*.wav" \
8
+
9
+
10
+ sh run.sh --stage 3 --stop_stage 3 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
11
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
12
+
13
+ sh run.sh --stage 4 --stop_stage 4 --system_version windows --file_folder_name file_dir --final_model_name vm_sound_classification8 \
14
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
15
+
16
+ sh run.sh --stage 4 --stop_stage 4 --system_version centos --file_folder_name file_dir --final_model_name vm_sound_classification8 \
17
+ --filename_patterns "/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
18
+
19
+
20
+ "
21
+
22
+ END
23
+
24
+
25
+ # sh run.sh --stage -1 --stop_stage 9
26
+ # sh run.sh --stage -1 --stop_stage 5 --system_version centos --file_folder_name task_cnn_voicemail_id_id --final_model_name cnn_voicemail_id_id
27
+ # sh run.sh --stage 3 --stop_stage 4
28
+ # sh run.sh --stage 4 --stop_stage 4
29
+ # sh run.sh --stage 3 --stop_stage 3 --system_version centos --file_folder_name task_cnn_voicemail_id_id
30
+
31
+ # params
32
+ system_version="windows";
33
+ verbose=true;
34
+ stage=0 # start from 0 if you need to start from data preparation
35
+ stop_stage=9
36
+
37
+ work_dir="$(pwd)"
38
+ file_folder_name=file_folder_name
39
+ final_model_name=final_model_name
40
+ filename_patterns="/data/tianxing/PycharmProjects/datasets/voicemail/*/wav_finished/*/*.wav"
41
+ nohup_name=nohup.out
42
+
43
+ country=en-US
44
+
45
+ # model params
46
+ batch_size=64
47
+ max_epochs=200
48
+ save_top_k=10
49
+ patience=5
50
+
51
+
52
+ # parse options
53
+ while true; do
54
+ [ -z "${1:-}" ] && break; # break if there are no arguments
55
+ case "$1" in
56
+ --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
57
+ eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
58
+ old_value="$(eval echo \$$name)";
59
+ if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
60
+ was_bool=true;
61
+ else
62
+ was_bool=false;
63
+ fi
64
+
65
+ # Set the variable to the right value-- the escaped quotes make it work if
66
+ # the option had spaces, like --cmd "queue.pl -sync y"
67
+ eval "${name}=\"$2\"";
68
+
69
+ # Check that Boolean-valued arguments are really Boolean.
70
+ if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
71
+ echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
72
+ exit 1;
73
+ fi
74
+ shift 2;
75
+ ;;
76
+
77
+ *) break;
78
+ esac
79
+ done
80
+
81
+ file_dir="${work_dir}/${file_folder_name}"
82
+ final_model_dir="${work_dir}/../../trained_models/${final_model_name}";
83
+
84
+ train_dataset="${file_dir}/train.xlsx"
85
+ valid_dataset="${file_dir}/valid.xlsx"
86
+ vocabulary_dir="${file_dir}/vocabulary"
87
+
88
+
89
+ $verbose && echo "system_version: ${system_version}"
90
+ $verbose && echo "file_folder_name: ${file_folder_name}"
91
+
92
+ if [ $system_version == "windows" ]; then
93
+ alias python3='D:/Users/tianx/PycharmProjects/virtualenv/vm_sound_classification/Scripts/python.exe'
94
+ elif [ $system_version == "centos" ] || [ $system_version == "ubuntu" ]; then
95
+ #source /data/local/bin/vm_sound_classification/bin/activate
96
+ alias python3='/data/local/bin/vm_sound_classification/bin/python3'
97
+ fi
98
+
99
+
100
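+ # A stage N runs iff stage <= N <= stop_stage, so any contiguous slice of the pipeline can be replayed.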
+ if [ ${stage} -le 0 ] && [ ${stop_stage} -ge 0 ]; then
101
+ $verbose && echo "stage 0: prepare data"
102
+ cd "${work_dir}" || exit 1
103
+ python3 step_1_prepare_data.py \
104
+ --file_dir "${file_dir}" \
105
+ --filename_patterns "${filename_patterns}" \
106
+ --train_dataset "${train_dataset}" \
107
+ --valid_dataset "${valid_dataset}" \
108
+
109
+ fi
110
+
111
+
112
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
113
+ $verbose && echo "stage 1: make vocabulary"
114
+ cd "${work_dir}" || exit 1
115
+ python3 step_2_make_vocabulary.py \
116
+ --vocabulary_dir "${vocabulary_dir}" \
117
+ --train_dataset "${train_dataset}" \
118
+ --valid_dataset "${valid_dataset}" \
119
+
120
+ fi
121
+
122
+
123
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
124
+ $verbose && echo "stage 2: train global model"
125
+ cd "${work_dir}" || exit 1
126
+ python3 step_3_train_global_model.py \
127
+ --vocabulary_dir "${vocabulary_dir}" \
128
+ --train_dataset "${train_dataset}" \
129
+ --valid_dataset "${valid_dataset}" \
130
+ --serialization_dir "${file_dir}/global_model" \
131
+
132
+ fi
133
+
134
+
135
+ if [ ${stage} -le 3 ] && [ ${stop_stage} -ge 3 ]; then
136
+ $verbose && echo "stage 3: train country model"
137
+ cd "${work_dir}" || exit 1
138
+ python3 step_4_train_country_model.py \
139
+ --vocabulary_dir "${vocabulary_dir}" \
140
+ --train_dataset "${train_dataset}" \
141
+ --valid_dataset "${valid_dataset}" \
142
+ --country "${country}" \
143
+ --serialization_dir "${file_dir}/country_model" \
144
+
145
+ fi
146
+
147
+
148
+ if [ ${stage} -le 4 ] && [ ${stop_stage} -ge 4 ]; then
149
+ $verbose && echo "stage 4: train union model"
150
+ cd "${work_dir}" || exit 1
151
+ python3 step_5_train_union.py \
152
+ --vocabulary_dir "${vocabulary_dir}" \
153
+ --train_dataset "${train_dataset}" \
154
+ --valid_dataset "${valid_dataset}" \
155
+ --serialization_dir "${file_dir}/union" \
156
+
157
+ fi
examples/vm_sound_classification8/step_1_prepare_data.py ADDED
@@ -0,0 +1,156 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ from glob import glob
5
+ import json
6
+ import os
7
+ from pathlib import Path
8
+ import random
9
+ import sys
10
+
11
+ pwd = os.path.abspath(os.path.dirname(__file__))
12
+ sys.path.append(os.path.join(pwd, "../../"))
13
+
14
+ import pandas as pd
15
+ from scipy.io import wavfile
16
+ from tqdm import tqdm
17
+
18
+
19
+ def get_args():
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument("--file_dir", default="./", type=str)
22
+ parser.add_argument("--task", default="default", type=str)
23
+ parser.add_argument("--filename_patterns", type=str)
24
+
25
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
26
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
27
+
28
+ args = parser.parse_args()
29
+ return args
30
+
31
+
32
+ def get_dataset(args):
33
+ filename_patterns = args.filename_patterns
34
+ filename_patterns = filename_patterns.split(" ")
35
+ print(filename_patterns)
36
+
37
+ file_dir = Path(args.file_dir)
38
+ file_dir.mkdir(exist_ok=True)
39
+
40
+ global_label_map = {
41
+ "bell": "bell",
42
+ "white_noise": "white_noise",
43
+ "low_white_noise": "white_noise",
44
+ "high_white_noise": "noise",
45
+ "music": "music",
46
+ "mute": "mute",
47
+ "noise": "noise",
48
+ "noise_mute": "noise_mute",
49
+ "voice": "voice",
50
+ "voicemail": "voicemail",
51
+ }
52
+
53
+ country_label_map = {
54
+ "bell": "voicemail",
55
+ "white_noise": "non_voicemail",
56
+ "low_white_noise": "non_voicemail",
57
+ "hight_white_noise": "non_voicemail",
58
+ "music": "non_voicemail",
59
+ "mute": "non_voicemail",
60
+ "noise": "non_voicemail",
61
+ "noise_mute": "non_voicemail",
62
+ "voice": "non_voicemail",
63
+ "voicemail": "voicemail",
64
+ "non_voicemail": "non_voicemail",
65
+ }
66
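+ # Two label granularities: global_labels keep the fine-grained sound classes, country_labels collapse them to voicemail vs non_voicemail.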
+
67
+ result = list()
68
+ for filename_pattern in filename_patterns:
69
+ filename_list = glob(filename_pattern)
70
+ for filename in tqdm(filename_list):
71
+ filename = Path(filename)
72
+ sample_rate, signal = wavfile.read(filename.as_posix())
73
+ if len(signal) < sample_rate * 2:
74
+ continue
75
+
76
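+ # Paths are assumed to look like .../<country>/wav_finished/<label_folder>/<name>.wav.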
+ folder = filename.parts[-2]
77
+ country = filename.parts[-4]
78
+
79
+ if folder not in global_label_map.keys():
80
+ continue
81
+ if folder not in country_label_map.keys():
82
+ continue
83
+
84
+ global_label = global_label_map[folder]
85
+ country_label = country_label_map[folder]
86
+
87
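+ # random1 only shuffles the output sheet; random2 drives a rough 80/20 train/test split.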
+ random1 = random.random()
88
+ random2 = random.random()
89
+
90
+ result.append({
91
+ "filename": filename,
92
+ "folder": folder,
93
+ "category": country,
94
+ "global_labels": global_label,
95
+ "country_labels": country_label,
96
+ "random1": random1,
97
+ "random2": random2,
98
+ "flag": "TRAIN" if random2 < 0.8 else "TEST",
99
+ })
100
+
101
+ df = pd.DataFrame(result)
102
+ pivot_table = pd.pivot_table(df, index=["global_labels"], values=["filename"], aggfunc="count")
103
+ print(pivot_table)
104
+
105
+ df = df.sort_values(by=["random1"], ascending=False)
106
+ df.to_excel(
107
+ file_dir / "dataset.xlsx",
108
+ index=False,
109
+ # encoding="utf_8_sig"
110
+ )
111
+
112
+ return
113
+
114
+
115
+ def split_dataset(args):
116
+ """分割训练集, 测试集"""
117
+ file_dir = Path(args.file_dir)
118
+ file_dir.mkdir(exist_ok=True)
119
+
120
+ df = pd.read_excel(file_dir / "dataset.xlsx")
121
+
122
+ train = list()
123
+ test = list()
124
+
125
+ for i, row in df.iterrows():
126
+ flag = row["flag"]
127
+ if flag == "TRAIN":
128
+ train.append(row)
129
+ else:
130
+ test.append(row)
131
+
132
+ train = pd.DataFrame(train)
133
+ train.to_excel(
134
+ args.train_dataset,
135
+ index=False,
136
+ # encoding="utf_8_sig"
137
+ )
138
+ test = pd.DataFrame(test)
139
+ test.to_excel(
140
+ args.valid_dataset,
141
+ index=False,
142
+ # encoding="utf_8_sig"
143
+ )
144
+
145
+ return
146
+
147
+
148
+ def main():
149
+ args = get_args()
150
+ get_dataset(args)
151
+ split_dataset(args)
152
+ return
153
+
154
+
155
+ if __name__ == "__main__":
156
+ main()
examples/vm_sound_classification8/step_2_make_vocabulary.py ADDED
@@ -0,0 +1,69 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ import argparse
4
+ import os
5
+ from pathlib import Path
6
+ import sys
7
+
8
+ pwd = os.path.abspath(os.path.dirname(__file__))
9
+ sys.path.append(os.path.join(pwd, "../../"))
10
+
11
+ import pandas as pd
12
+
13
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
14
+
15
+
16
+ def get_args():
17
+ parser = argparse.ArgumentParser()
18
+ parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
19
+
20
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
21
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
22
+
23
+ args = parser.parse_args()
24
+ return args
25
+
26
+
27
+ def main():
28
+ args = get_args()
29
+
30
+ train_dataset = pd.read_excel(args.train_dataset)
31
+ valid_dataset = pd.read_excel(args.valid_dataset)
32
+
33
+ # non_padded_namespaces
34
+ category_set = set()
35
+ for i, row in train_dataset.iterrows():
36
+ category = row["category"]
37
+ category_set.add(category)
38
+
39
+ for i, row in valid_dataset.iterrows():
40
+ category = row["category"]
41
+ category_set.add(category)
42
+
43
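+ # One label namespace per country plus "global_labels"; marking them non-padded keeps padding/OOV tokens out of the label space.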
+ vocabulary = Vocabulary(non_padded_namespaces=["global_labels", *list(category_set)])
44
+
45
+ # train
46
+ for i, row in train_dataset.iterrows():
47
+ global_labels = row["global_labels"]
48
+ country_labels = row["country_labels"]
49
+ category = row["category"]
50
+
51
+ vocabulary.add_token_to_namespace(global_labels, "global_labels")
52
+ vocabulary.add_token_to_namespace(country_labels, category)
53
+
54
+ # valid
55
+ for i, row in valid_dataset.iterrows():
56
+ global_labels = row["global_labels"]
57
+ country_labels = row["country_labels"]
58
+ category = row["category"]
59
+
60
+ vocabulary.add_token_to_namespace(global_labels, "global_labels")
61
+ vocabulary.add_token_to_namespace(country_labels, category)
62
+
63
+ vocabulary.save_to_files(args.vocabulary_dir)
64
+
65
+ return
66
+
67
+
68
+ if __name__ == "__main__":
69
+ main()
examples/vm_sound_classification8/step_3_train_global_model.py ADDED
@@ -0,0 +1,328 @@
1
+ #!/usr/bin/python3
2
+ # -*- coding: utf-8 -*-
3
+ """
4
+ The previous version of this code reached an accuracy of 0.8423.
5
+ This version reaches an accuracy of 0.8379.
6
+ This approach is workable.
7
+ """
8
+ import argparse
9
+ import copy
10
+ import json
11
+ import logging
12
+ from logging.handlers import TimedRotatingFileHandler
13
+ import os
14
+ from pathlib import Path
15
+ import platform
16
+ import sys
17
+ from typing import List
18
+
19
+ pwd = os.path.abspath(os.path.dirname(__file__))
20
+ sys.path.append(os.path.join(pwd, "../../"))
21
+
22
+ import torch
23
+ from torch.utils.data.dataloader import DataLoader
24
+ from tqdm import tqdm
25
+
26
+ from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
27
+ from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
28
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
29
+ from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
30
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
31
+
32
+
33
+ def get_args():
34
+ parser = argparse.ArgumentParser()
35
+ parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
36
+
37
+ parser.add_argument("--train_dataset", default="train.xlsx", type=str)
38
+ parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
39
+
40
+ parser.add_argument("--max_epochs", default=100, type=int)
41
+ parser.add_argument("--batch_size", default=64, type=int)
42
+ parser.add_argument("--learning_rate", default=1e-3, type=float)
43
+ parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
44
+ parser.add_argument("--patience", default=5, type=int)
45
+ parser.add_argument("--serialization_dir", default="global_classifier", type=str)
46
+ parser.add_argument("--seed", default=0, type=int)
47
+
48
+ args = parser.parse_args()
49
+ return args
50
+
51
+
52
+ def logging_config(file_dir: str):
53
+ fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
54
+
55
+ logging.basicConfig(format=fmt,
56
+ datefmt="%m/%d/%Y %H:%M:%S",
57
+ level=logging.DEBUG)
58
+ file_handler = TimedRotatingFileHandler(
59
+ filename=os.path.join(file_dir, "main.log"),
60
+ encoding="utf-8",
61
+ when="D",
62
+ interval=1,
63
+ backupCount=7
64
+ )
65
+ file_handler.setLevel(logging.INFO)
66
+ file_handler.setFormatter(logging.Formatter(fmt))
67
+ logger = logging.getLogger(__name__)
68
+ logger.addHandler(file_handler)
69
+
70
+ return logger
71
+
72
+
73
+ class CollateFunction(object):
74
+ def __init__(self):
75
+ pass
76
+
77
+ def __call__(self, batch: List[dict]):
78
+ array_list = list()
79
+ label_list = list()
80
+ for sample in batch:
81
+ array = sample["waveform"]
82
+ label = sample["label"]
83
+
84
+ array_list.append(array)
85
+ label_list.append(label)
86
+
87
+ array_list = torch.stack(array_list)
88
+ label_list = torch.stack(label_list)
89
+ return array_list, label_list
90
+
91
+
92
+ collate_fn = CollateFunction()
93
+
94
+
95
+ def main():
96
+ args = get_args()
97
+
98
+ serialization_dir = Path(args.serialization_dir)
99
+ serialization_dir.mkdir(parents=True, exist_ok=True)
100
+
101
+ logger = logging_config(args.serialization_dir)
102
+
103
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
104
+ n_gpu = torch.cuda.device_count()
105
+ logger.info("GPU available: {}; device: {}".format(n_gpu, device))
106
+
107
+ vocabulary = Vocabulary.from_files(args.vocabulary_dir)
108
+
109
+ # datasets
110
+ train_dataset = WaveClassifierExcelDataset(
111
+ vocab=vocabulary,
112
+ excel_file=args.train_dataset,
113
+ category=None,
114
+ category_field="category",
115
+ label_field="global_labels",
116
+ expected_sample_rate=8000,
117
+ max_wave_value=32768.0,
118
+ )
119
+ valid_dataset = WaveClassifierExcelDataset(
120
+ vocab=vocabulary,
121
+ excel_file=args.valid_dataset,
122
+ category=None,
123
+ category_field="category",
124
+ label_field="global_labels",
125
+ expected_sample_rate=8000,
126
+ max_wave_value=32768.0,
127
+ )
128
+
129
+ train_data_loader = DataLoader(
130
+ dataset=train_dataset,
131
+ batch_size=args.batch_size,
132
+ shuffle=True,
133
+ # Multiple dataloader worker subprocesses are supported on Linux, but not on Windows.
134
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
135
+ collate_fn=collate_fn,
136
+ pin_memory=False,
137
+ # prefetch_factor=64,
138
+ )
139
+ valid_data_loader = DataLoader(
140
+ dataset=valid_dataset,
141
+ batch_size=args.batch_size,
142
+ shuffle=True,
143
+ # Multiple dataloader worker subprocesses are supported on Linux, but not on Windows.
144
+ num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
145
+ collate_fn=collate_fn,
146
+ pin_memory=False,
147
+ # prefetch_factor=64,
148
+ )
149
+
150
+ # models - classifier
151
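+ # Three stride-3 Conv1d blocks over 80-band mel frames: a 27x reduction in time resolution before the 16-dim cls head.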
+ wave_encoder = WaveEncoder(
152
+ conv1d_block_param_list=[
153
+ {
154
+ 'batch_norm': True,
155
+ 'in_channels': 80,
156
+ 'out_channels': 16,
157
+ 'kernel_size': 3,
158
+ 'stride': 3,
159
+ # 'padding': 'same',
160
+ 'activation': 'relu',
161
+ 'dropout': 0.1,
162
+ },
163
+ {
164
+ # 'batch_norm': True,
165
+ 'in_channels': 16,
166
+ 'out_channels': 16,
167
+ 'kernel_size': 3,
168
+ 'stride': 3,
169
+ # 'padding': 'same',
170
+ 'activation': 'relu',
171
+ 'dropout': 0.1,
172
+ },
173
+ {
174
+ # 'batch_norm': True,
175
+ 'in_channels': 16,
176
+ 'out_channels': 16,
177
+ 'kernel_size': 3,
178
+ 'stride': 3,
179
+ # 'padding': 'same',
180
+ 'activation': 'relu',
181
+ 'dropout': 0.1,
182
+ },
183
+ ],
184
+ mel_spectrogram_param={
185
+ "sample_rate": 8000,
186
+ "n_fft": 512,
187
+ "win_length": 200,
188
+ "hop_length": 80,
189
+ "f_min": 10,
190
+ "f_max": 3800,
191
+ "window_fn": "hamming",
192
+ "n_mels": 80,
193
+ }
194
+ )
195
+ cls_head = ClsHead(
196
+ input_dim=16,
197
+ num_layers=2,
198
+ hidden_dims=[32, 16],
199
+ activations="relu",
200
+ dropout=0.1,
201
+ num_labels=vocabulary.get_vocab_size(namespace="global_labels")
202
+ )
203
+ model = WaveClassifier(
204
+ wave_encoder=wave_encoder,
205
+ cls_head=cls_head,
206
+ )
207
+ model.to(device)
208
+
209
+ # optimizer
210
+ optimizer = torch.optim.Adam(
211
+ model.parameters(),
212
+ lr=args.learning_rate
213
+ )
214
+ lr_scheduler = torch.optim.lr_scheduler.StepLR(
215
+ optimizer,
216
+ step_size=30000
217
+ )
218
+ focal_loss = FocalLoss(
219
+ num_classes=vocabulary.get_vocab_size(namespace="global_labels"),
220
+ reduction="mean",
221
+ )
222
+ categorical_accuracy = CategoricalAccuracy()
223
+
224
+ # training
225
+ best_idx_epoch: int = None
226
+ best_accuracy: float = None
227
+ patience_count = 0
228
+ global_step = 0
229
+ model_filename_list = list()
230
+ for idx_epoch in range(args.max_epochs):
231
+
232
+ # training
233
+ model.train()
234
+ total_loss = 0
235
+ total_examples = 0
236
+ for step, batch in enumerate(tqdm(train_data_loader, desc="Epoch={} (training)".format(idx_epoch))):
237
+ input_ids, label_ids = batch
238
+ input_ids = input_ids.to(device)
239
+ label_ids: torch.LongTensor = label_ids.to(device).long()
240
+
241
+ logits = model.forward(input_ids)
242
+ loss = focal_loss.forward(logits, label_ids.view(-1))
243
+ categorical_accuracy(logits, label_ids)
244
+
245
+ total_loss += loss.item()
246
+ total_examples += input_ids.size(0)
247
+
248
+ optimizer.zero_grad()
249
+ loss.backward()
250
+ optimizer.step()
251
+ lr_scheduler.step()
252
+
253
+ global_step += 1
254
+ training_loss = total_loss / total_examples
255
+ training_loss = round(training_loss, 4)
256
+ training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
257
+ training_accuracy = round(training_accuracy, 4)
258
+ logger.info("Epoch: {}; training_loss: {}; training_accuracy: {}".format(
259
+ idx_epoch, training_loss, training_accuracy
260
+ ))
261
+
262
+ # evaluation
263
+ model.eval()
264
+ total_loss = 0
265
+ total_examples = 0
266
+ for step, batch in enumerate(tqdm(valid_data_loader, desc="Epoch={} (evaluation)".format(idx_epoch))):
267
+ input_ids, label_ids = batch
268
+ input_ids = input_ids.to(device)
269
+ label_ids: torch.LongTensor = label_ids.to(device).long()
270
+
271
+ with torch.no_grad():
272
+ logits = model.forward(input_ids)
273
+ loss = focal_loss.forward(logits, label_ids.view(-1))
274
+ categorical_accuracy(logits, label_ids)
275
+
276
+ total_loss += loss.item()
277
+ total_examples += input_ids.size(0)
278
+
279
+ evaluation_loss = total_loss / total_examples
280
+ evaluation_loss = round(evaluation_loss, 4)
281
+ evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
282
+ evaluation_accuracy = round(evaluation_accuracy, 4)
283
+ logger.info("Epoch: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
284
+ idx_epoch, evaluation_loss, evaluation_accuracy
285
+ ))
286
+
287
+ # save metric
288
+ metrics = {
289
+ "training_loss": training_loss,
290
+ "training_accuracy": training_accuracy,
291
+ "evaluation_loss": evaluation_loss,
292
+ "evaluation_accuracy": evaluation_accuracy,
293
+ "best_idx_epoch": best_idx_epoch,
294
+ "best_accuracy": best_accuracy,
295
+ }
296
+ metrics_filename = os.path.join(args.serialization_dir, "metrics_epoch_{}.json".format(idx_epoch))
297
+ with open(metrics_filename, "w", encoding="utf-8") as f:
298
+ json.dump(metrics, f, indent=4, ensure_ascii=False)
299
+
300
+ # save model
301
+ model_filename = os.path.join(args.serialization_dir, "model_epoch_{}.bin".format(idx_epoch))
302
+ model_filename_list.append(model_filename)
303
+ if len(model_filename_list) >= args.num_serialized_models_to_keep:
304
+ model_filename_to_delete = model_filename_list.pop(0)
305
+ os.remove(model_filename_to_delete)
306
+ torch.save(model.state_dict(), model_filename)
307
+
308
+ # early stop
309
+ best_model_filename = os.path.join(args.serialization_dir, "best.bin")
310
+ if best_accuracy is None:
311
+ best_idx_epoch = idx_epoch
312
+ best_accuracy = evaluation_accuracy
313
+ torch.save(model.state_dict(), best_model_filename)
314
+ elif evaluation_accuracy > best_accuracy:
315
+ best_idx_epoch = idx_epoch
316
+ best_accuracy = evaluation_accuracy
317
+ torch.save(model.state_dict(), best_model_filename)
318
+ patience_count = 0
319
+ elif patience_count >= args.patience:
320
+ break
321
+ else:
322
+ patience_count += 1
323
+
324
+ return
325
+
326
+
327
+ if __name__ == "__main__":
328
+ main()
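+
+ # Usage sketch for CollateFunction (illustrative; assumes the dataset yields
+ # fixed-length waveform tensors, here 16000 samples):
+ #
+ #   batch = [
+ #       {"waveform": torch.zeros(16000), "label": torch.tensor(0)},
+ #       {"waveform": torch.zeros(16000), "label": torch.tensor(1)},
+ #   ]
+ #   inputs, labels = collate_fn(batch)
+ #   # inputs.shape == (2, 16000); labels.shape == (2,)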
examples/vm_sound_classification8/step_4_train_country_model.py ADDED
@@ -0,0 +1,349 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ """
+ Only the cls_head parameters are trained here; the resulting model accuracy is lower.
+ """
+ import argparse
+ from collections import defaultdict
+ import json
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ import os
+ import platform
+ from pathlib import Path
+ import sys
+ import shutil
+ from typing import List
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import pandas as pd
+ import torch
+ from torch.utils.data.dataloader import DataLoader
+ from tqdm import tqdm
+
+ from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
+ from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
+ from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
+     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+     parser.add_argument("--country", default="en-US", type=str)
+     parser.add_argument("--shared_encoder", default="file_dir/global_model/best.bin", type=str)
+
+     parser.add_argument("--max_epochs", default=100, type=int)
+     parser.add_argument("--batch_size", default=64, type=int)
+     parser.add_argument("--learning_rate", default=1e-3, type=float)
+     parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
+     parser.add_argument("--patience", default=5, type=int)
+     parser.add_argument("--serialization_dir", default="country_models", type=str)
+     parser.add_argument("--seed", default=0, type=int)
+
+     args = parser.parse_args()
+     return args
+
+
+ def logging_config(file_dir: str):
+     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+     logging.basicConfig(format=fmt,
+                         datefmt="%m/%d/%Y %H:%M:%S",
+                         level=logging.DEBUG)
+     file_handler = TimedRotatingFileHandler(
+         filename=os.path.join(file_dir, "main.log"),
+         encoding="utf-8",
+         when="D",
+         interval=1,
+         backupCount=7
+     )
+     file_handler.setLevel(logging.INFO)
+     file_handler.setFormatter(logging.Formatter(fmt))
+     logger = logging.getLogger(__name__)
+     logger.addHandler(file_handler)
+
+     return logger
+
+
+ class CollateFunction(object):
+     def __init__(self):
+         pass
+
+     def __call__(self, batch: List[dict]):
+         array_list = list()
+         label_list = list()
+         for sample in batch:
+             array = sample['waveform']
+             label = sample['label']
+
+             array_list.append(array)
+             label_list.append(label)
+
+         array_list = torch.stack(array_list)
+         label_list = torch.stack(label_list)
+         return array_list, label_list
+
+
+ collate_fn = CollateFunction()
+
+
+ def main():
+     args = get_args()
+
+     serialization_dir = Path(args.serialization_dir)
+     serialization_dir.mkdir(parents=True, exist_ok=True)
+
+     logger = logging_config(args.serialization_dir)
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     n_gpu = torch.cuda.device_count()
+     logger.info("GPU available: {}; device: {}".format(n_gpu, device))
+
+     vocabulary = Vocabulary.from_files(args.vocabulary_dir)
+
+     # datasets
+     logger.info("prepare datasets")
+     train_dataset = WaveClassifierExcelDataset(
+         vocab=vocabulary,
+         excel_file=args.train_dataset,
+         category=args.country,
+         category_field="category",
+         label_field="country_labels",
+         expected_sample_rate=8000,
+         max_wave_value=32768.0,
+     )
+     valid_dataset = WaveClassifierExcelDataset(
+         vocab=vocabulary,
+         excel_file=args.valid_dataset,
+         category=args.country,
+         category_field="category",
+         label_field="country_labels",
+         expected_sample_rate=8000,
+         max_wave_value=32768.0,
+     )
+
+     train_data_loader = DataLoader(
+         dataset=train_dataset,
+         batch_size=args.batch_size,
+         shuffle=True,
+         # Multiple worker processes can be used to load data on Linux, but not on Windows.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
+         collate_fn=collate_fn,
+         pin_memory=False,
+         # prefetch_factor=64,
+     )
+     valid_data_loader = DataLoader(
+         dataset=valid_dataset,
+         batch_size=args.batch_size,
+         shuffle=True,
+         # Multiple worker processes can be used to load data on Linux, but not on Windows.
+         num_workers=0 if platform.system() == "Windows" else os.cpu_count(),
+         collate_fn=collate_fn,
+         pin_memory=False,
+         # prefetch_factor=64,
+     )
+
+     # models - classifier
+     wave_encoder = WaveEncoder(
+         conv1d_block_param_list=[
+             {
+                 'batch_norm': True,
+                 'in_channels': 80,
+                 'out_channels': 16,
+                 'kernel_size': 3,
+                 'stride': 3,
+                 # 'padding': 'same',
+                 'activation': 'relu',
+                 'dropout': 0.1,
+             },
+             {
+                 # 'batch_norm': True,
+                 'in_channels': 16,
+                 'out_channels': 16,
+                 'kernel_size': 3,
+                 'stride': 3,
+                 # 'padding': 'same',
+                 'activation': 'relu',
+                 'dropout': 0.1,
+             },
+             {
+                 # 'batch_norm': True,
+                 'in_channels': 16,
+                 'out_channels': 16,
+                 'kernel_size': 3,
+                 'stride': 3,
+                 # 'padding': 'same',
+                 'activation': 'relu',
+                 'dropout': 0.1,
+             },
+         ],
+         mel_spectrogram_param={
+             "sample_rate": 8000,
+             "n_fft": 512,
+             "win_length": 200,
+             "hop_length": 80,
+             "f_min": 10,
+             "f_max": 3800,
+             "window_fn": "hamming",
+             "n_mels": 80,
+         }
+     )
+
+     # load the shared encoder trained in the previous step, stripping the
+     # "wave_encoder." prefix from the full-model state dict.
+     with open(args.shared_encoder, "rb") as f:
+         state_dict = torch.load(f, map_location=device)
+     processed_state_dict = dict()
+     prefix = "wave_encoder."
+     for k, v in state_dict.items():
+         if not str(k).startswith(prefix):
+             continue
+         k = k[len(prefix):]
+         processed_state_dict[k] = v
+
+     wave_encoder.load_state_dict(
+         state_dict=processed_state_dict,
+         strict=True,
+     )
+     cls_head = ClsHead(
+         input_dim=16,
+         num_layers=2,
+         hidden_dims=[32, 16],
+         activations="relu",
+         dropout=0.1,
+         # size the head by the country-specific label namespace, matching
+         # label_field="country_labels" and the focal loss below.
+         num_labels=vocabulary.get_vocab_size(namespace=args.country)
+     )
+     model = WaveClassifier(
+         wave_encoder=wave_encoder,
+         cls_head=cls_head,
+     )
+     # freeze the shared encoder; only the classification head is trained.
+     model.wave_encoder.requires_grad_(requires_grad=False)
+     model.cls_head.requires_grad_(requires_grad=True)
+     model.to(device)
+
+     # optimizer
+     logger.info("prepare optimizer")
+     optimizer = torch.optim.Adam(
+         model.cls_head.parameters(),
+         lr=args.learning_rate,
+     )
+     lr_scheduler = torch.optim.lr_scheduler.StepLR(
+         optimizer,
+         step_size=2000
+     )
+     focal_loss = FocalLoss(
+         num_classes=vocabulary.get_vocab_size(namespace=args.country),
+         reduction="mean",
+     )
+     categorical_accuracy = CategoricalAccuracy()
+
+     # training loop
+     best_idx_epoch: int = None
+     best_accuracy: float = None
+     patience_count = 0
+     global_step = 0
+     model_filename_list = list()
+     for idx_epoch in range(args.max_epochs):
+
+         # training
+         model.train()
+         total_loss = 0
+         total_examples = 0
+         for step, batch in enumerate(tqdm(train_data_loader, desc="Epoch={} (training)".format(idx_epoch))):
+             input_ids, label_ids = batch
+             input_ids = input_ids.to(device)
+             label_ids: torch.LongTensor = label_ids.to(device).long()
+
+             logits = model.forward(input_ids)
+             loss = focal_loss.forward(logits, label_ids.view(-1))
+             categorical_accuracy(logits, label_ids)
+
+             # loss is averaged over the batch; weight by the batch size for a
+             # per-example epoch average.
+             total_loss += loss.item() * input_ids.size(0)
+             total_examples += input_ids.size(0)
+
+             optimizer.zero_grad()
+             loss.backward()
+             optimizer.step()
+             lr_scheduler.step()
+
+             global_step += 1
+         training_loss = total_loss / total_examples
+         training_loss = round(training_loss, 4)
+         training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
+         training_accuracy = round(training_accuracy, 4)
+         logger.info("Epoch: {}; training_loss: {}; training_accuracy: {}".format(
+             idx_epoch, training_loss, training_accuracy
+         ))
+
+         # evaluation
+         model.eval()
+         total_loss = 0
+         total_examples = 0
+         for step, batch in enumerate(tqdm(valid_data_loader, desc="Epoch={} (evaluation)".format(idx_epoch))):
+             input_ids, label_ids = batch
+             input_ids = input_ids.to(device)
+             label_ids: torch.LongTensor = label_ids.to(device).long()
+
+             with torch.no_grad():
+                 logits = model.forward(input_ids)
+                 loss = focal_loss.forward(logits, label_ids.view(-1))
+                 categorical_accuracy(logits, label_ids)
+
+             total_loss += loss.item() * input_ids.size(0)
+             total_examples += input_ids.size(0)
+
+         evaluation_loss = total_loss / total_examples
+         evaluation_loss = round(evaluation_loss, 4)
+         evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
+         evaluation_accuracy = round(evaluation_accuracy, 4)
+         logger.info("Epoch: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
+             idx_epoch, evaluation_loss, evaluation_accuracy
+         ))
+
+         # save metric
+         metrics = {
+             "training_loss": training_loss,
+             "training_accuracy": training_accuracy,
+             "evaluation_loss": evaluation_loss,
+             "evaluation_accuracy": evaluation_accuracy,
+             "best_idx_epoch": best_idx_epoch,
+             "best_accuracy": best_accuracy,
+         }
+         metrics_filename = os.path.join(args.serialization_dir, "metrics_epoch_{}.json".format(idx_epoch))
+         with open(metrics_filename, "w", encoding="utf-8") as f:
+             json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+         # save model
+         model_filename = os.path.join(args.serialization_dir, "model_epoch_{}.bin".format(idx_epoch))
+         model_filename_list.append(model_filename)
+         if len(model_filename_list) >= args.num_serialized_models_to_keep:
+             model_filename_to_delete = model_filename_list.pop(0)
+             os.remove(model_filename_to_delete)
+         torch.save(model.state_dict(), model_filename)
+
+         # early stop
+         best_model_filename = os.path.join(args.serialization_dir, "best.bin")
+         if best_accuracy is None:
+             best_idx_epoch = idx_epoch
+             best_accuracy = evaluation_accuracy
+             torch.save(model.state_dict(), best_model_filename)
+         elif evaluation_accuracy > best_accuracy:
+             best_idx_epoch = idx_epoch
+             best_accuracy = evaluation_accuracy
+             torch.save(model.state_dict(), best_model_filename)
+             patience_count = 0
+         elif patience_count >= args.patience:
+             break
+         else:
+             patience_count += 1
+
+     return
+
+
+ if __name__ == "__main__":
+     main()
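+
+ # Sanity-check sketch (illustrative; assumes WaveClassifier is a plain
+ # nn.Module exposing the wave_encoder/cls_head attributes used above):
+ # after the requires_grad_ calls, only cls_head parameters should be trainable.
+ #
+ #   trainable = [n for n, p in model.named_parameters() if p.requires_grad]
+ #   assert all(n.startswith("cls_head.") for n in trainable)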
examples/vm_sound_classification8/step_5_train_union.py ADDED
@@ -0,0 +1,499 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from collections import defaultdict
+ import json
+ import logging
+ from logging.handlers import TimedRotatingFileHandler
+ import os
+ import platform
+ from pathlib import Path
+ import sys
+ import shutil
+ from typing import List
+
+ pwd = os.path.abspath(os.path.dirname(__file__))
+ sys.path.append(os.path.join(pwd, "../../"))
+
+ import pandas as pd
+ import torch
+ from torch.utils.data.dataloader import DataLoader
+ from tqdm import tqdm
+
+ from toolbox.torch.modules.loss import FocalLoss, HingeLoss, HingeLinear
+ from toolbox.torch.training.metrics.categorical_accuracy import CategoricalAccuracy
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
+ from toolbox.torch.utils.data.dataset.wave_classifier_excel_dataset import WaveClassifierExcelDataset
+ from toolbox.torchaudio.models.cnn_audio_classifier.modeling_cnn_audio_classifier import WaveEncoder, ClsHead, WaveClassifier
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument("--vocabulary_dir", default="vocabulary", type=str)
+
+     parser.add_argument("--train_dataset", default="train.xlsx", type=str)
+     parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
+
+     parser.add_argument("--max_steps", default=100000, type=int)
+     parser.add_argument("--save_steps", default=30, type=int)
+
+     parser.add_argument("--batch_size", default=1, type=int)
+     parser.add_argument("--learning_rate", default=1e-3, type=float)
+     parser.add_argument("--num_serialized_models_to_keep", default=10, type=int)
+     parser.add_argument("--patience", default=5, type=int)
+     parser.add_argument("--serialization_dir", default="union", type=str)
+     parser.add_argument("--seed", default=0, type=int)
+
+     parser.add_argument("--num_workers", default=0, type=int)
+
+     args = parser.parse_args()
+     return args
+
+
+ def logging_config(file_dir: str):
+     fmt = "%(asctime)s - %(name)s - %(levelname)s %(filename)s:%(lineno)d > %(message)s"
+
+     logging.basicConfig(format=fmt,
+                         datefmt="%m/%d/%Y %H:%M:%S",
+                         level=logging.DEBUG)
+     file_handler = TimedRotatingFileHandler(
+         filename=os.path.join(file_dir, "main.log"),
+         encoding="utf-8",
+         when="D",
+         interval=1,
+         backupCount=7
+     )
+     file_handler.setLevel(logging.INFO)
+     file_handler.setFormatter(logging.Formatter(fmt))
+     logger = logging.getLogger(__name__)
+     logger.addHandler(file_handler)
+
+     return logger
+
+
+ class CollateFunction(object):
+     def __init__(self):
+         pass
+
+     def __call__(self, batch: List[dict]):
+         array_list = list()
+         label_list = list()
+         for sample in batch:
+             array = sample['waveform']
+             label = sample['label']
+
+             array_list.append(array)
+             label_list.append(label)
+
+         array_list = torch.stack(array_list)
+         label_list = torch.stack(label_list)
+         return array_list, label_list
+
+
+ collate_fn = CollateFunction()
+
+
+ class DatasetIterator(object):
+     """Cycles through a DataLoader endlessly, restarting it on StopIteration."""
+     def __init__(self, data_loader: DataLoader):
+         self.data_loader = data_loader
+         self.data_loader_iter = iter(self.data_loader)
+
+     def next(self):
+         try:
+             result = next(self.data_loader_iter)
+         except StopIteration:
+             self.data_loader_iter = iter(self.data_loader)
+             result = next(self.data_loader_iter)
+         return result
+
+
+ def main():
+     args = get_args()
+
+     serialization_dir = Path(args.serialization_dir)
+     serialization_dir.mkdir(parents=True, exist_ok=True)
+
+     logger = logging_config(args.serialization_dir)
+
+     device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+     n_gpu = torch.cuda.device_count()
+     logger.info("GPU available count: {}; device: {}".format(n_gpu, device))
+
+     vocabulary = Vocabulary.from_files(args.vocabulary_dir)
+     namespaces = vocabulary._token_to_index.keys()
+
+     # namespace_to_ratio: how many training batches each task gets per step;
+     # the shared "global_labels" task starts with the maximum ratio.
+     max_ratio = (len(namespaces) - 1) * 3
+     namespace_to_ratio = {n: 1 for n in namespaces}
+     namespace_to_ratio["global_labels"] = max_ratio
+
+     # datasets
+     logger.info("prepare datasets")
+     namespace_to_datasets = dict()
+     for namespace in namespaces:
+         logger.info("prepare datasets - {}".format(namespace))
+         if namespace == "global_labels":
+             train_dataset = WaveClassifierExcelDataset(
+                 vocab=vocabulary,
+                 excel_file=args.train_dataset,
+                 category=None,
+                 category_field="category",
+                 label_field="global_labels",
+                 expected_sample_rate=8000,
+                 max_wave_value=32768.0,
+             )
+             valid_dataset = WaveClassifierExcelDataset(
+                 vocab=vocabulary,
+                 excel_file=args.valid_dataset,
+                 category=None,
+                 category_field="category",
+                 label_field="global_labels",
+                 expected_sample_rate=8000,
+                 max_wave_value=32768.0,
+             )
+         else:
+             train_dataset = WaveClassifierExcelDataset(
+                 vocab=vocabulary,
+                 excel_file=args.train_dataset,
+                 category=namespace,
+                 category_field="category",
+                 label_field="country_labels",
+                 expected_sample_rate=8000,
+                 max_wave_value=32768.0,
+             )
+             valid_dataset = WaveClassifierExcelDataset(
+                 vocab=vocabulary,
+                 excel_file=args.valid_dataset,
+                 category=namespace,
+                 category_field="category",
+                 label_field="country_labels",
+                 expected_sample_rate=8000,
+                 max_wave_value=32768.0,
+             )
+
+         train_data_loader = DataLoader(
+             dataset=train_dataset,
+             batch_size=args.batch_size,
+             shuffle=True,
+             # Multiple worker processes can be used to load data on Linux, but not on Windows.
+             # num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+             num_workers=args.num_workers,
+             collate_fn=collate_fn,
+             pin_memory=False,
+             # prefetch_factor=64,
+         )
+         valid_data_loader = DataLoader(
+             dataset=valid_dataset,
+             batch_size=args.batch_size,
+             shuffle=True,
+             # Multiple worker processes can be used to load data on Linux, but not on Windows.
+             # num_workers=0 if platform.system() == "Windows" else os.cpu_count() // 2,
+             num_workers=args.num_workers,
+             collate_fn=collate_fn,
+             pin_memory=False,
+             # prefetch_factor=64,
+         )
+
+         namespace_to_datasets[namespace] = {
+             "train_data_loader": train_data_loader,
+             "valid_data_loader": valid_data_loader,
+         }
+
+     # datasets iterator
+     logger.info("prepare datasets iterator")
+     namespace_to_datasets_iter = dict()
+     for namespace in namespaces:
+         logger.info("prepare datasets iterator - {}".format(namespace))
+         train_data_loader = namespace_to_datasets[namespace]["train_data_loader"]
+         valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
+         namespace_to_datasets_iter[namespace] = {
+             "train_data_loader_iter": DatasetIterator(train_data_loader),
+             "valid_data_loader_iter": DatasetIterator(valid_data_loader),
+         }
+
+     # models - encoder
+     logger.info("prepare models - encoder")
+     wave_encoder = WaveEncoder(
+         conv2d_block_param_list=[
+             {
+                 "batch_norm": True,
+                 "in_channels": 1,
+                 "out_channels": 4,
+                 "kernel_size": 3,
+                 "stride": 1,
+                 # "padding": "same",
+                 "dilation": 3,
+                 "activation": "relu",
+                 "dropout": 0.1,
+             },
+             {
+                 # "batch_norm": True,
+                 "in_channels": 4,
+                 "out_channels": 4,
+                 "kernel_size": 5,
+                 "stride": 2,
+                 # "padding": "same",
+                 "dilation": 3,
+                 "activation": "relu",
+                 "dropout": 0.1,
+             },
+             {
+                 # "batch_norm": True,
+                 "in_channels": 4,
+                 "out_channels": 4,
+                 "kernel_size": 3,
+                 "stride": 1,
+                 # "padding": "same",
+                 "dilation": 2,
+                 "activation": "relu",
+                 "dropout": 0.1,
+             },
+         ],
+         mel_spectrogram_param={
+             'sample_rate': 8000,
+             'n_fft': 512,
+             'win_length': 200,
+             'hop_length': 80,
+             'f_min': 10,
+             'f_max': 3800,
+             'window_fn': 'hamming',
+             'n_mels': 80,
+         }
+     )
+
+     # models - cls_head
+     logger.info("prepare models - cls_head")
+     namespace_to_cls_heads = dict()
+     for namespace in namespaces:
+         logger.info("prepare models - cls_head - {}".format(namespace))
+         cls_head = ClsHead(
+             input_dim=352,
+             num_layers=2,
+             hidden_dims=[128, 32],
+             activations="relu",
+             dropout=0.1,
+             num_labels=vocabulary.get_vocab_size(namespace=namespace)
+         )
+         namespace_to_cls_heads[namespace] = cls_head
+
+     # models - classifier: every task shares one encoder but owns its head.
+     logger.info("prepare models - classifier")
+     namespace_to_classifier = dict()
+     for namespace in namespaces:
+         logger.info("prepare models - classifier - {}".format(namespace))
+         cls_head = namespace_to_cls_heads[namespace]
+         wave_classifier = WaveClassifier(
+             wave_encoder=wave_encoder,
+             cls_head=cls_head,
+         )
+         wave_classifier.to(device)
+         namespace_to_classifier[namespace] = wave_classifier
+
+     # optimizer
+     logger.info("prepare optimizer")
+     param_optimizer = list()
+     param_optimizer.extend(wave_encoder.parameters())
+     for _, cls_head in namespace_to_cls_heads.items():
+         param_optimizer.extend(cls_head.parameters())
+
+     optimizer = torch.optim.Adam(
+         param_optimizer,
+         lr=args.learning_rate,
+     )
+     lr_scheduler = torch.optim.lr_scheduler.StepLR(
+         optimizer,
+         step_size=10000
+     )
+     # each task predicts over its own label set, so each namespace gets its own loss.
+     namespace_to_focal_loss = dict()
+     for namespace in namespaces:
+         namespace_to_focal_loss[namespace] = FocalLoss(
+             num_classes=vocabulary.get_vocab_size(namespace=namespace),
+             reduction="mean",
+         )
+
+     # categorical_accuracy
+     logger.info("prepare categorical_accuracy")
+     namespace_to_categorical_accuracy = dict()
+     for namespace in namespaces:
+         categorical_accuracy = CategoricalAccuracy()
+         namespace_to_categorical_accuracy[namespace] = categorical_accuracy
+
+     # training loop
+     logger.info("prepare training loop")
+
+     model_list = list()
+     best_idx_step = None
+     best_accuracy = None
+     patience_count = 0
+
+     namespace_to_total_loss = defaultdict(float)
+     namespace_to_total_examples = defaultdict(int)
+     for idx_step in tqdm(range(args.max_steps)):
+
+         # training one step: each task contributes `ratio` batches to the joint loss.
+         loss: torch.Tensor = None
+         for namespace in namespaces:
+             train_data_loader_iter = namespace_to_datasets_iter[namespace]["train_data_loader_iter"]
+
+             ratio = namespace_to_ratio[namespace]
+             model = namespace_to_classifier[namespace]
+             focal_loss = namespace_to_focal_loss[namespace]
+             categorical_accuracy = namespace_to_categorical_accuracy[namespace]
+
+             model.train()
+
+             for _ in range(ratio):
+                 batch = train_data_loader_iter.next()
+                 input_ids, label_ids = batch
+                 input_ids = input_ids.to(device)
+                 label_ids: torch.LongTensor = label_ids.to(device).long()
+
+                 logits = model.forward(input_ids)
+                 task_loss = focal_loss.forward(logits, label_ids.view(-1))
+                 categorical_accuracy(logits, label_ids)
+
+                 if loss is None:
+                     loss = task_loss
+                 else:
+                     loss += task_loss
+
+                 # task_loss is a batch mean; weight by batch size so the logged
+                 # value is a per-example average.
+                 namespace_to_total_loss[namespace] += task_loss.item() * input_ids.size(0)
+                 namespace_to_total_examples[namespace] += input_ids.size(0)
+
+         optimizer.zero_grad()
+         loss.backward()
+         optimizer.step()
+         lr_scheduler.step()
+
+         # logging
+         if (idx_step + 1) % args.save_steps == 0:
+             metrics = dict()
+
+             # training
+             for namespace in namespaces:
+                 total_loss = namespace_to_total_loss[namespace]
+                 total_examples = namespace_to_total_examples[namespace]
+
+                 training_loss = total_loss / total_examples
+                 training_loss = round(training_loss, 4)
+
+                 categorical_accuracy = namespace_to_categorical_accuracy[namespace]
+
+                 training_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
+                 training_accuracy = round(training_accuracy, 4)
+                 logger.info("Step: {}; namespace: {}; training_loss: {}; training_accuracy: {}".format(
+                     idx_step, namespace, training_loss, training_accuracy
+                 ))
+                 metrics[namespace] = {
+                     "training_loss": training_loss,
+                     "training_accuracy": training_accuracy,
+                 }
+             namespace_to_total_loss = defaultdict(float)
+             namespace_to_total_examples = defaultdict(int)
+
+             # evaluation
+             for namespace in namespaces:
+                 valid_data_loader = namespace_to_datasets[namespace]["valid_data_loader"]
+
+                 model = namespace_to_classifier[namespace]
+                 focal_loss = namespace_to_focal_loss[namespace]
+                 categorical_accuracy = namespace_to_categorical_accuracy[namespace]
+
+                 model.eval()
+
+                 total_loss = 0
+                 total_examples = 0
+                 for step, batch in enumerate(valid_data_loader):
+                     input_ids, label_ids = batch
+                     input_ids = input_ids.to(device)
+                     label_ids: torch.LongTensor = label_ids.to(device).long()
+
+                     with torch.no_grad():
+                         logits = model.forward(input_ids)
+                         loss = focal_loss.forward(logits, label_ids.view(-1))
+                         categorical_accuracy(logits, label_ids)
+
+                     total_loss += loss.item() * input_ids.size(0)
+                     total_examples += input_ids.size(0)
+
+                 evaluation_loss = total_loss / total_examples
+                 evaluation_loss = round(evaluation_loss, 4)
+                 evaluation_accuracy = categorical_accuracy.get_metric(reset=True)["accuracy"]
+                 evaluation_accuracy = round(evaluation_accuracy, 4)
+                 logger.info("Step: {}; namespace: {}; evaluation_loss: {}; evaluation_accuracy: {}".format(
+                     idx_step, namespace, evaluation_loss, evaluation_accuracy
+                 ))
+                 # update rather than overwrite, so the training metrics are kept.
+                 metrics[namespace].update({
+                     "evaluation_loss": evaluation_loss,
+                     "evaluation_accuracy": evaluation_accuracy,
+                 })
+
+             # update ratio: tasks with lower evaluation accuracy get more
+             # training batches on the following steps.
+             min_accuracy = min([m["evaluation_accuracy"] for m in metrics.values()])
+             max_accuracy = max([m["evaluation_accuracy"] for m in metrics.values()])
+             width = max_accuracy - min_accuracy
+             if width > 0:
+                 # when all tasks tie, keep the previous ratios (avoids division by zero).
+                 for namespace, metric in metrics.items():
+                     evaluation_accuracy = metric["evaluation_accuracy"]
+                     ratio = (max_accuracy - evaluation_accuracy) / width * max_ratio
+                     namespace_to_ratio[namespace] = int(ratio)
+
+             msg = "".join(["{}: {}; ".format(k, v) for k, v in namespace_to_ratio.items()])
+             logger.info("namespace to ratio: {}".format(msg))
+
+             # save path
+             step_dir = serialization_dir / "step-{}".format(idx_step)
+             step_dir.mkdir(parents=True, exist_ok=False)
+
+             # save models
+             wave_encoder_filename = step_dir / "wave_encoder.pt"
+             torch.save(wave_encoder.state_dict(), wave_encoder_filename)
+             for namespace in namespaces:
+                 cls_head_filename = step_dir / "{}.pt".format(namespace)
+                 cls_head = namespace_to_cls_heads[namespace]
+                 torch.save(cls_head.state_dict(), cls_head_filename)
+
+             model_list.append(step_dir)
+             if len(model_list) >= args.num_serialized_models_to_keep:
+                 model_to_delete: Path = model_list.pop(0)
+                 shutil.rmtree(model_to_delete.as_posix())
+
+             # save metric
+             this_accuracy = metrics["global_labels"]["evaluation_accuracy"]
+             if best_accuracy is None:
+                 best_idx_step = idx_step
+                 best_accuracy = this_accuracy
+             elif this_accuracy > best_accuracy:
+                 best_idx_step = idx_step
+                 best_accuracy = this_accuracy
+
+             metrics_filename = step_dir / "metrics_epoch.json"
+             metrics.update({
+                 "idx_step": idx_step,
+                 "best_idx_step": best_idx_step,
+             })
+             with open(metrics_filename, "w", encoding="utf-8") as f:
+                 json.dump(metrics, f, indent=4, ensure_ascii=False)
+
+             # save best
+             best_dir = serialization_dir / "best"
+             if best_idx_step == idx_step:
+                 if best_dir.exists():
+                     shutil.rmtree(best_dir)
+                 shutil.copytree(step_dir, best_dir)
+
+             # early stop
+             if best_idx_step == idx_step:
+                 patience_count = 0
+             else:
+                 patience_count += 1
+                 if patience_count >= args.patience:
+                     break
+     return
+
+
+ if __name__ == "__main__":
+     main()
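+
+ # Worked example of the ratio update (illustrative numbers): with three
+ # namespaces, max_ratio = (3 - 1) * 3 = 6. Given evaluation accuracies
+ # {global_labels: 0.95, en-US: 0.80, es-MX: 0.90}, width = 0.95 - 0.80 = 0.15:
+ #   en-US:         (0.95 - 0.80) / 0.15 * 6 = 6 batches per step
+ #   es-MX:         (0.95 - 0.90) / 0.15 * 6 = 2 batches per step
+ #   global_labels: (0.95 - 0.95) / 0.15 * 6 = 0 batches per step
+ # i.e. the weakest task gets the most gradient updates until it catches up.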
examples/vm_sound_classification8/stop.sh ADDED
@@ -0,0 +1,3 @@
+ #!/usr/bin/env bash
+
+ kill -9 `ps -aef | grep 'vm_sound_classification/bin/python3' | grep -v grep | awk '{print $2}' | sed 's/\n/ /'`
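+
+ # A pkill-based alternative matching the same interpreter path (untested sketch):
+ #   pkill -9 -f 'vm_sound_classification/bin/python3'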
install.sh ADDED
@@ -0,0 +1,64 @@
+ #!/usr/bin/env bash
+
+ # bash install.sh --stage 2 --stop_stage 2 --system_version centos
+
+
+ python_version=3.8.10
+ system_version="centos";
+
+ verbose=true;
+ stage=-1
+ stop_stage=0
+
+
+ # parse options
+ while true; do
+   [ -z "${1:-}" ] && break;  # break if there are no arguments
+   case "$1" in
+     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+       old_value="$(eval echo \$$name)";
+       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+         was_bool=true;
+       else
+         was_bool=false;
+       fi
+
+       # Set the variable to the right value-- the escaped quotes make it work if
+       # the option had spaces, like --cmd "queue.pl -sync y"
+       eval "${name}=\"$2\"";
+
+       # Check that Boolean-valued arguments are really Boolean.
+       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+         exit 1;
+       fi
+       shift 2;
+       ;;
+
+     *) break;
+   esac
+ done
+
+ work_dir="$(pwd)"
+
+
+ if [ ${stage} -le 1 ] && [ ${stop_stage} -ge 1 ]; then
+   $verbose && echo "stage 1: install python"
+   cd "${work_dir}" || exit 1;
+
+   sh ./script/install_python.sh --python_version "${python_version}" --system_version "${system_version}"
+ fi
+
+
+ if [ ${stage} -le 2 ] && [ ${stop_stage} -ge 2 ]; then
+   $verbose && echo "stage 2: create virtualenv"
+
+   # /usr/local/python-3.6.5/bin/virtualenv vm_sound_classification
+   # source /data/local/bin/vm_sound_classification/bin/activate
+   /usr/local/python-${python_version}/bin/pip3 install virtualenv
+   mkdir -p /data/local/bin
+   cd /data/local/bin || exit 1;
+   /usr/local/python-${python_version}/bin/virtualenv vm_sound_classification
+
+ fi
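+
+ # Example invocation (runs both stages, then activates the virtualenv):
+ #   bash install.sh --stage 1 --stop_stage 2 --system_version centos
+ #   source /data/local/bin/vm_sound_classification/bin/activate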
main.py ADDED
@@ -0,0 +1,206 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import argparse
+ from functools import lru_cache
+ from pathlib import Path
+ import platform
+ import shutil
+ import tempfile
+ import zipfile
+ from typing import Tuple
+
+ import gradio as gr
+ from huggingface_hub import snapshot_download
+ import numpy as np
+ import torch
+
+ from project_settings import environment, project_path
+ from toolbox.torch.utils.data.vocabulary import Vocabulary
+
+
+ def get_args():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--examples_dir",
+         # default=(project_path / "data").as_posix(),
+         default=(project_path / "data/examples").as_posix(),
+         type=str
+     )
+     parser.add_argument(
+         "--models_repo_id",
+         default="qgyd2021/vm_sound_classification",
+         type=str
+     )
+     parser.add_argument(
+         "--trained_model_dir",
+         default=(project_path / "trained_models").as_posix(),
+         type=str
+     )
+     parser.add_argument(
+         "--hf_token",
+         default=environment.get("hf_token"),
+         type=str,
+     )
+     parser.add_argument(
+         "--server_port",
+         default=environment.get("server_port", 7860),
+         type=int
+     )
+
+     args = parser.parse_args()
+     return args
+
+
+ @lru_cache(maxsize=100)
+ def load_model(model_file: Path):
+     with zipfile.ZipFile(model_file, "r") as f_zip:
+         out_root = Path(tempfile.gettempdir()) / "vm_sound_classification"
+         if out_root.exists():
+             shutil.rmtree(out_root.as_posix())
+         out_root.mkdir(parents=True, exist_ok=True)
+         f_zip.extractall(path=out_root)
+
+     tgt_path = out_root / model_file.stem
+     jit_model_file = tgt_path / "trace_model.zip"
+     vocab_path = tgt_path / "vocabulary"
+
+     vocabulary = Vocabulary.from_files(vocab_path.as_posix())
+
+     with open(jit_model_file.as_posix(), "rb") as f:
+         model = torch.jit.load(f)
+     model.eval()
+
+     shutil.rmtree(tgt_path)
+
+     d = {
+         "model": model,
+         "vocabulary": vocabulary
+     }
+     return d
+
+
+ def click_button(audio: Tuple[int, np.ndarray],
+                  model_name: str,
+                  ground_true: str) -> Tuple[str, float]:
+
+     sample_rate, signal = audio
+
+     model_file = "trained_models/{}.zip".format(model_name)
+     model_file = Path(model_file)
+     d = load_model(model_file)
+
+     model = d["model"]
+     vocabulary = d["vocabulary"]
+
+     inputs = signal / (1 << 15)
+     inputs = torch.tensor(inputs, dtype=torch.float32)
+     inputs = torch.unsqueeze(inputs, dim=0)
+
+     with torch.no_grad():
+         logits = model.forward(inputs)
+         probs = torch.nn.functional.softmax(logits, dim=-1)
+         label_idx = torch.argmax(probs, dim=-1)
+
+     label_idx = label_idx.cpu()
+     probs = probs.cpu()
+
+     label_idx = label_idx.numpy()[0]
+     prob = probs.numpy()[0][label_idx]
+
+     label_str = vocabulary.get_token_from_index(label_idx, namespace="labels")
+
+     return label_str, round(prob, 4)
+
+
+ def main():
+     args = get_args()
+
+     examples_dir = Path(args.examples_dir)
+     trained_model_dir = Path(args.trained_model_dir)
+
+     # download models
+     if not trained_model_dir.exists():
+         trained_model_dir.mkdir(parents=True, exist_ok=True)
+         _ = snapshot_download(
+             repo_id=args.models_repo_id,
+             local_dir=trained_model_dir.as_posix(),
+             token=args.hf_token,
+         )
+
+     # examples
+     example_zip_file = trained_model_dir / "examples.zip"
+     with zipfile.ZipFile(example_zip_file.as_posix(), "r") as f_zip:
+         out_root = examples_dir
+         if out_root.exists():
+             shutil.rmtree(out_root.as_posix())
+         out_root.mkdir(parents=True, exist_ok=True)
+         f_zip.extractall(path=out_root)
+
+     # models
+     model_choices = list()
+     for filename in trained_model_dir.glob("*.zip"):
+         model_name = filename.stem
+         if model_name == "examples":
+             continue
+         model_choices.append(model_name)
+     model_choices = list(sorted(model_choices))
+
+     # examples
+     examples = list()
+     for filename in examples_dir.glob("**/*/*.wav"):
+         label = filename.parts[-2]
+
+         examples.append([
+             filename.as_posix(),
+             model_choices[0],
+             label
+         ])
+
+     # ui
+     brief_description = """
+     International intelligent outbound-call voice system; telephone sound classification; 8000 Hz, int16.
+     """
+
+     with gr.Blocks() as blocks:
+         gr.Markdown(value=brief_description)
+
+         with gr.Row():
+             with gr.Column(scale=3):
+                 c_audio = gr.Audio(label="audio")
+                 with gr.Row():
+                     with gr.Column(scale=3):
+                         c_model_name = gr.Dropdown(choices=model_choices, value=model_choices[0], label="model_name")
+                     with gr.Column(scale=3):
+                         c_ground_true = gr.Textbox(label="ground_true")
+
+                 c_button = gr.Button("run", variant="primary")
+             with gr.Column(scale=3):
+                 c_label = gr.Textbox(label="label")
+                 c_probability = gr.Number(label="probability")
+
+         gr.Examples(
+             examples,
+             inputs=[c_audio, c_model_name, c_ground_true],
+             outputs=[c_label, c_probability],
+             fn=click_button,
+             examples_per_page=5,
+         )
+
+         c_button.click(
+             click_button,
+             inputs=[c_audio, c_model_name, c_ground_true],
+             outputs=[c_label, c_probability],
+         )
+
+     # http://127.0.0.1:7864/
+     blocks.queue().launch(
+         share=False,
+         server_name="127.0.0.1" if platform.system() == "Windows" else "0.0.0.0",
+         server_port=args.server_port
+     )
+     return
+
+
+ if __name__ == "__main__":
+     main()
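+
+ # Normalization sketch: the model expects int16 PCM scaled to [-1, 1);
+ # 1 << 15 == 32768, matching the max_wave_value used at training time
+ # (illustrative values):
+ #
+ #   signal = np.array([16384, -32768], dtype=np.int16)
+ #   inputs = signal / (1 << 15)   # -> array([ 0.5, -1. ])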
project_settings.py ADDED
@@ -0,0 +1,19 @@
+ #!/usr/bin/python3
+ # -*- coding: utf-8 -*-
+ import os
+ from pathlib import Path
+
+ from toolbox.os.environment import EnvironmentManager
+
+
+ project_path = os.path.abspath(os.path.dirname(__file__))
+ project_path = Path(project_path)
+
+ environment = EnvironmentManager(
+     path=os.path.join(project_path, "dotenv"),
+     env=os.environ.get("environment", "dev"),
+ )
+
+
+ if __name__ == '__main__':
+     pass
requirements.txt ADDED
@@ -0,0 +1,13 @@
+ torch==2.3.0
+ torchaudio==2.3.0
+ fsspec==2024.5.0
+ librosa==0.10.2
+ pandas==2.0.3
+ openpyxl==3.0.9
+ xlrd==1.2.0
+ tqdm==4.66.4
+ overrides==1.9.0
+ pyyaml==6.0.1
+ evaluate==0.4.2
+ gradio
+ python-dotenv==1.0.1
script/install_nvidia_driver.sh ADDED
@@ -0,0 +1,184 @@
+ #!/usr/bin/env bash
+ # Installing the GPU driver requires disabling the existing display driver and rebooting the machine before installing.
+ # Reference links:
+ # https://blog.csdn.net/kingschan/article/details/19033595
+ # https://blog.csdn.net/HaixWang/article/details/90408538
+ #
+ # >>> yum install -y pciutils
+ # Check whether the Linux machine has a GPU:
+ # lspci | grep -i nvidia
+ #
+ # >>> lspci | grep -i nvidia
+ # 00:08.0 3D controller: NVIDIA Corporation TU104GL [Tesla T4] (rev a1)
+ #
+ #
+ # NVIDIA driver download.
+ # First check on the PyTorch site which CUDA version to use, then install the matching cuda-toolkit.
+ # Then download and install the NVIDIA driver matching the GPU model.
+ #
+ # # PyTorch versions
+ # https://pytorch.org/get-started/locally/
+ #
+ # # CUDA downloads (probably not needed)
+ # https://developer.nvidia.com/cuda-toolkit-archive
+ #
+ # # NVIDIA drivers
+ # https://www.nvidia.cn/Download/index.aspx?lang=cn
+ # http://www.nvidia.com/Download/index.aspx
+ #
+ # Use the dropdown lists on the download page to pick the right driver for your NVIDIA product.
+ # Product type:
+ # Data Center / Tesla
+ # Product series:
+ # T-Series
+ # Product family:
+ # Tesla T4
+ # Operating system:
+ # Linux 64-bit
+ # CUDA Toolkit:
+ # 10.2
+ # Language:
+ # Chinese (Simplified)
+ #
+ #
+ # >>> mkdir -p /data/tianxing
+ # >>> cd /data/tianxing
+ # >>> wget https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
+ # >>> sh NVIDIA-Linux-x86_64-440.118.02.run
+ #
+ # # Error:
+ # ERROR: The Nouveau kernel driver is currently in use by your system. This driver is incompatible with the NVIDIA driver, and must be disabled before proceeding. Please consult the NVIDIA driver README and your
+ # Linux distribution's documentation for details on how to correctly disable the Nouveau kernel driver.
+ # [OK]
+ #
+ # For some distributions, Nouveau can be disabled by adding a file in the modprobe configuration directory. Would you like nvidia-installer to attempt to create this modprobe file for you?
+ # [NO]
+ #
+ # ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
+ # page at www.nvidia.com.
+ # [OK]
+ #
+ # # Reference:
+ # https://blog.csdn.net/kingschan/article/details/19033595
+ #
+ # # Disable the stock nouveau display driver
+ # >>> echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
+ # >>> sudo dracut --force
+ # # reboot the machine
+ # >>> reboot
+ #
+ # >>> init 3
+ # >>> sh NVIDIA-Linux-x86_64-440.118.02.run
+ #
+ # # Error
+ # ERROR: Unable to find the kernel source tree for the currently running kernel. Please make sure you have installed the kernel source files for your kernel and that they are properly configured; on Red Hat Linux systems, for example, be sure you have the 'kernel-source' or 'kernel-devel' RPM installed. If you know the correct kernel source files are installed, you may specify the kernel source path with the '--kernel-source-path' command line option.
+ # [OK]
+ # ERROR: Installation has failed. Please see the file '/var/log/nvidia-installer.log' for details. You may find suggestions on fixing installation problems in the README available on the Linux driver download
+ # page at www.nvidia.com.
+ # [OK]
+ #
+ # # Reference
+ # # https://blog.csdn.net/HaixWang/article/details/90408538
+ #
+ # >>> uname -r
+ # 3.10.0-1160.49.1.el7.x86_64
+ # >>> yum install kernel-devel kernel-headers -y
+ # >>> yum info kernel-devel kernel-headers
+ # >>> yum install -y "kernel-devel-uname-r == $(uname -r)"
+ # >>> yum -y distro-sync
+ #
+ # >>> sh NVIDIA-Linux-x86_64-440.118.02.run
+ #
+ # # Installation succeeded
+ # WARNING: nvidia-installer was forced to guess the X library path '/usr/lib64' and X module path '/usr/lib64/xorg/modules'; these paths were not queryable from the system. If X fails to find the NVIDIA X driver
+ # module, please install the `pkg-config` utility and the X.Org SDK/development package for your distribution and reinstall the driver.
+ # [OK]
+ # Install NVIDIA's 32-bit compatibility libraries?
+ # [YES]
+ # Installation of the kernel module for the NVIDIA Accelerated Graphics Driver for Linux-x86_64 (version 440.118.02) is now complete.
+ # [OK]
+ #
+ #
+ # # Check GPU usage; `watch -n 1 -d nvidia-smi` refreshes once per second.
+ # >>> nvidia-smi
+ # Thu Mar 9 12:00:37 2023
+ # +-----------------------------------------------------------------------------+
+ # | NVIDIA-SMI 440.118.02 Driver Version: 440.118.02 CUDA Version: 10.2 |
+ # |-------------------------------+----------------------+----------------------+
+ # | GPU Name Persistence-M| Bus-Id Disp.A | Volatile Uncorr. ECC |
+ # | Fan Temp Perf Pwr:Usage/Cap| Memory-Usage | GPU-Util Compute M. |
+ # |===============================+======================+======================|
+ # | 0 Tesla T4 Off | 00000000:00:08.0 Off | Off |
+ # | N/A 54C P0 22W / 70W | 0MiB / 16127MiB | 0% Default |
+ # +-------------------------------+----------------------+----------------------+
+ #
+ # +-----------------------------------------------------------------------------+
+ # | Processes: GPU Memory |
+ # | GPU PID Type Process name Usage |
+ # |=============================================================================|
+ # | No running processes found |
+ # +-----------------------------------------------------------------------------+
+ #
+ #
+
+ # params
+ stage=1
+ nvidia_driver_filename=https://cn.download.nvidia.com/tesla/440.118.02/NVIDIA-Linux-x86_64-440.118.02.run
+
+ # parse options
+ while true; do
+   [ -z "${1:-}" ] && break;  # break if there are no arguments
+   case "$1" in
+     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+       old_value="$(eval echo \$$name)";
+       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+         was_bool=true;
+       else
+         was_bool=false;
+       fi
+
+       # Set the variable to the right value-- the escaped quotes make it work if
+       # the option had spaces, like --cmd "queue.pl -sync y"
+       eval "${name}=\"$2\"";
+
+       # Check that Boolean-valued arguments are really Boolean.
+       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+         exit 1;
+       fi
+       shift 2;
+       ;;
+
+     *) break;
+   esac
+ done
+
+ echo "stage: ${stage}";
+
+ yum -y install wget
+ yum -y install sudo
+
+ if [ ${stage} -eq 0 ]; then
+   mkdir -p /data/dep
+   cd /data/dep || exit 1;
+   wget -P /data/dep ${nvidia_driver_filename}
+
+   echo -e "blacklist nouveau\noptions nouveau modeset=0\n" > /etc/modprobe.d/blacklist-nouveau.conf
+   sudo dracut --force
+   # reboot the machine
+   reboot
+ elif [ ${stage} -eq 1 ]; then
+   init 3
+
+   yum install -y kernel-devel kernel-headers
+   yum info kernel-devel kernel-headers
+   yum install -y "kernel-devel-uname-r == $(uname -r)"
+   yum -y distro-sync
+
+   cd /data/dep || exit 1;
+
+   # the installer prompts three times; press Enter at each prompt.
+   sh NVIDIA-Linux-x86_64-440.118.02.run
+   nvidia-smi
+ fi
script/install_python.sh ADDED
@@ -0,0 +1,129 @@
+ #!/usr/bin/env bash
+
+ # parameters:
+ python_version="3.6.5";
+ system_version="centos";
+
+
+ # parse options
+ while true; do
+   [ -z "${1:-}" ] && break;  # break if there are no arguments
+   case "$1" in
+     --*) name=$(echo "$1" | sed s/^--// | sed s/-/_/g);
+       eval '[ -z "${'"$name"'+xxx}" ]' && echo "$0: invalid option $1" 1>&2 && exit 1;
+       old_value="$(eval echo \$$name)";
+       if [ "${old_value}" == "true" ] || [ "${old_value}" == "false" ]; then
+         was_bool=true;
+       else
+         was_bool=false;
+       fi
+
+       # Set the variable to the right value-- the escaped quotes make it work if
+       # the option had spaces, like --cmd "queue.pl -sync y"
+       eval "${name}=\"$2\"";
+
+       # Check that Boolean-valued arguments are really Boolean.
+       if $was_bool && [[ "$2" != "true" && "$2" != "false" ]]; then
+         echo "$0: expected \"true\" or \"false\": $1 $2" 1>&2
+         exit 1;
+       fi
+       shift 2;
+       ;;
+
+     *) break;
+   esac
+ done
+
+ echo "python_version: ${python_version}";
+ echo "system_version: ${system_version}";
+
+
+ if [ ${system_version} = "centos" ]; then
+   # install the Python build and development environment
+   yum -y groupinstall "Development tools"
+   yum -y install zlib-devel bzip2-devel openssl-devel ncurses-devel sqlite-devel readline-devel tk-devel gdbm-devel db4-devel libpcap-devel xz-devel
+   yum install libffi-devel -y
+   yum install -y wget
+   yum install -y make
+
+   mkdir -p /data/dep
+   cd /data/dep || exit 1;
+   if [ ! -e Python-${python_version}.tgz ]; then
+     wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
+   fi
+
+   cd /data/dep || exit 1;
+   if [ ! -d Python-${python_version} ]; then
+     tar -zxvf Python-${python_version}.tgz
+   fi
+   # enter the source directory even when it was already extracted on a previous run.
+   cd /data/dep/Python-${python_version} || exit 1;
+
+   mkdir /usr/local/python-${python_version}
+   ./configure --prefix=/usr/local/python-${python_version}
+   make && make install
+
+   /usr/local/python-${python_version}/bin/python3 -V
+   /usr/local/python-${python_version}/bin/pip3 -V
+
+   rm -rf /usr/local/bin/python3
+   rm -rf /usr/local/bin/pip3
+   ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
+   ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3
+
+   python3 -V
+   pip3 -V
+
+ elif [ ${system_version} = "ubuntu" ]; then
+   # install the Python build and development environment
+   # https://zhuanlan.zhihu.com/p/506491209
+
+   # refresh the package index
+   sudo apt update
+   # list the currently available upgrades
+   sudo apt list --upgradable
+   # if the previous step reports upgradable packages, apply the upgrades
+   sudo apt -y upgrade
+   # install the GCC compiler
+   sudo apt install gcc
+   # check that the installation succeeded
+   gcc -v
+
+   # install dependencies
+   sudo apt install -y build-essential zlib1g-dev libncurses5-dev libgdbm-dev libnss3-dev libssl-dev libreadline-dev libffi-dev libbz2-dev liblzma-dev sqlite3 libsqlite3-dev tk-dev uuid-dev libgdbm-compat-dev
+
+   mkdir -p /data/dep
+   cd /data/dep || exit 1;
+   if [ ! -e Python-${python_version}.tgz ]; then
+     sudo wget -P /data/dep https://www.python.org/ftp/python/${python_version}/Python-${python_version}.tgz
+   fi
+
+   cd /data/dep || exit 1;
+   if [ ! -d Python-${python_version} ]; then
+     tar -zxvf Python-${python_version}.tgz
+   fi
+   # enter the source directory even when it was already extracted on a previous run.
+   cd /data/dep/Python-${python_version} || exit 1;
+
+   mkdir /usr/local/python-${python_version}
+
+   # check dependencies and configure the build
+   sudo ./configure --prefix=/usr/local/python-${python_version} --enable-optimizations --with-lto --enable-shared
+   cpu_count=$(cat /proc/cpuinfo | grep processor | wc -l)
+   sudo make -j "${cpu_count}"
+   # install the built interpreter under the prefix.
+   sudo make install
+
+   /usr/local/python-${python_version}/bin/python3 -V
+   /usr/local/python-${python_version}/bin/pip3 -V
+
+   rm -rf /usr/local/bin/python3
+   rm -rf /usr/local/bin/pip3
+   ln -s /usr/local/python-${python_version}/bin/python3 /usr/local/bin/python3
+   ln -s /usr/local/python-${python_version}/bin/pip3 /usr/local/bin/pip3
+
+   python3 -V
+   pip3 -V
+ fi