Spaces:
Running
Running
#!/usr/bin/python3 | |
# -*- coding: utf-8 -*- | |
import argparse | |
from glob import glob | |
import json | |
import os | |
from pathlib import Path | |
import random | |
import sys | |
# Directory containing this script; dirname(abspath(f)) is equivalent to
# abspath(dirname(f)) for any file path.
pwd = os.path.dirname(os.path.abspath(__file__))
# Make modules located two directories above this script importable.
sys.path.append(os.path.join(pwd, "../../"))
import pandas as pd | |
from scipy.io import wavfile | |
from tqdm import tqdm | |
def get_args():
    """Parse the command-line options for building and splitting the dataset.

    ``--filename_patterns`` is required: it holds one or more
    whitespace-separated glob patterns of wav files. ``--label_plan`` is
    restricted to the supported plans so a typo fails immediately with a
    usage message instead of deep inside dataset construction.

    NOTE(review): ``--task`` is parsed but not read by any code visible in
    this script; kept for compatibility with existing invocations.

    :return: the parsed ``argparse.Namespace``.
    """
    parser = argparse.ArgumentParser(
        description="Index labelled wav clips into an Excel sheet and split it into train/valid sets."
    )
    parser.add_argument(
        "--file_dir", default="./", type=str,
        help="directory where dataset.xlsx is written (created if missing)",
    )
    parser.add_argument("--task", default="default", type=str)
    parser.add_argument(
        "--filename_patterns", required=True, type=str,
        help="whitespace-separated glob patterns matching the wav files to index",
    )
    parser.add_argument(
        "--train_dataset", default="train.xlsx", type=str,
        help="output path of the training split",
    )
    parser.add_argument(
        "--valid_dataset", default="valid.xlsx", type=str,
        help="output path of the validation split",
    )
    parser.add_argument(
        "--label_plan", default="4", type=str, choices=["2", "3", "4", "8"],
        help="label mapping to apply; the value is the number of distinct labels",
    )
    args = parser.parse_args()
    return args
# Raw folder name -> training label for each --label_plan. The plan key is the
# number of distinct labels the plan produces. Folders absent from a plan's
# mapping (e.g. "music" in plans "3" and "4") are excluded from the dataset.
_LABEL_PLANS = {
    "2": {
        "bell": "non_voice",
        "white_noise": "non_voice",
        "low_white_noise": "non_voice",
        "high_white_noise": "non_voice",
        "music": "non_voice",
        "mute": "non_voice",
        "noise": "non_voice",
        "noise_mute": "non_voice",
        "voice": "voice",
        "voicemail": "voice",
    },
    "3": {
        "bell": "voicemail",
        "white_noise": "mute",
        "low_white_noise": "mute",
        "high_white_noise": "mute",
        "mute": "mute",
        "noise": "voice_or_noise",
        "noise_mute": "voice_or_noise",
        "voice": "voice_or_noise",
        "voicemail": "voicemail",
    },
    "4": {
        "bell": "voicemail",
        "white_noise": "mute",
        "low_white_noise": "mute",
        "high_white_noise": "mute",
        "mute": "mute",
        "noise": "noise",
        "noise_mute": "noise",
        "voice": "voice",
        "voicemail": "voicemail",
    },
    "8": {
        "bell": "bell",
        "white_noise": "white_noise",
        "low_white_noise": "white_noise",
        "high_white_noise": "white_noise",
        "music": "music",
        "mute": "mute",
        "noise": "noise",
        "noise_mute": "noise_mute",
        "voice": "voice",
        "voicemail": "voicemail",
    },
}


def _get_label_map(label_plan):
    """Return the folder -> label mapping for ``label_plan``.

    Raises AssertionError (the exception type this script has always used for
    an unknown plan) with a message listing the supported plans.
    """
    try:
        return _LABEL_PLANS[label_plan]
    except KeyError:
        raise AssertionError(
            "unsupported label_plan {!r}; expected one of {}".format(
                label_plan, sorted(_LABEL_PLANS)
            )
        ) from None


def get_dataset(args):
    """Index labelled wav clips and write them to ``<file_dir>/dataset.xlsx``.

    Every whitespace-separated glob pattern in ``args.filename_patterns`` is
    expanded. For each matched file, the parent directory name
    (``parts[-2]``) is the raw class folder, mapped to a training label via
    ``args.label_plan``. The directory three levels up (``parts[-4]``) is
    stored as ``category`` (presumably a country/region, per the original
    variable name; verify against the data layout). Files in folders the plan
    does not map, and clips shorter than 2 seconds, are skipped.

    Each kept row gets two uniform random numbers. ``random1`` only determines
    the shuffled output order. ``random2`` sets ``flag`` to "TRAIN" when below
    0.8, else "TEST". The global ``random`` generator is not seeded here, so
    seed it beforehand if a reproducible split is needed.

    :param args: namespace with ``file_dir``, ``filename_patterns`` and
        ``label_plan`` attributes (as produced by ``get_args``).
    :raises AssertionError: if ``args.label_plan`` is not a supported plan.
    :raises ValueError: if no pattern is given or no usable wav file matched.
    """
    # Validate the plan before touching the filesystem.
    label_map = _get_label_map(args.label_plan)
    if not args.filename_patterns:
        raise ValueError("filename_patterns must contain at least one glob pattern")

    # split() without an argument collapses runs of whitespace, so "a  b"
    # does not yield an empty pattern.
    filename_patterns = args.filename_patterns.split()
    print(filename_patterns)

    file_dir = Path(args.file_dir)
    file_dir.mkdir(parents=True, exist_ok=True)

    result = list()
    for filename_pattern in filename_patterns:
        for filename in tqdm(glob(filename_pattern)):
            filename = Path(filename)
            folder = filename.parts[-2]
            # Check the folder first so files that would be dropped anyway
            # are never decoded.
            if folder not in label_map:
                continue
            sample_rate, signal = wavfile.read(filename.as_posix())
            # len(signal) counts frames (first axis of wavfile.read's data);
            # keep only clips of at least 2 seconds.
            if len(signal) < sample_rate * 2:
                continue
            random1 = random.random()
            random2 = random.random()
            result.append({
                # POSIX string rather than a Path object: portable and
                # written to Excel as plain text.
                "filename": filename.as_posix(),
                "folder": folder,
                "category": filename.parts[-4],
                "labels": label_map[folder],
                "random1": random1,
                "random2": random2,
                "flag": "TRAIN" if random2 < 0.8 else "TEST",
            })

    if not result:
        # Without this, pivot_table below fails with an opaque KeyError.
        raise ValueError("no usable wav files matched {}".format(filename_patterns))

    df = pd.DataFrame(result)
    pivot_table = pd.pivot_table(df, index=["labels"], values=["filename"], aggfunc="count")
    print(pivot_table)

    # Sorting by a per-row random number shuffles the rows.
    df = df.sort_values(by=["random1"], ascending=False)
    df.to_excel(file_dir / "dataset.xlsx", index=False)
def split_dataset(args):
    """Split ``<file_dir>/dataset.xlsx`` into training and validation sets.

    Rows whose ``flag`` column equals "TRAIN" are written to
    ``args.train_dataset``. Every other row, including rows with a missing
    flag, is written to ``args.valid_dataset``. Both output paths are used as
    given, so relative paths resolve against the current working directory,
    not ``file_dir``.

    :param args: namespace with ``file_dir``, ``train_dataset`` and
        ``valid_dataset`` attributes (as produced by ``get_args``).
    """
    file_dir = Path(args.file_dir)
    file_dir.mkdir(parents=True, exist_ok=True)
    df = pd.read_excel(file_dir / "dataset.xlsx")

    # Vectorised mask instead of iterrows(). It is faster, preserves column
    # dtypes, and still writes the header row when one side has no rows.
    # NaN == "TRAIN" is False, so rows with a missing flag go to validation.
    is_train = df["flag"] == "TRAIN"
    df[is_train].to_excel(args.train_dataset, index=False)
    df[~is_train].to_excel(args.valid_dataset, index=False)
def main():
    """Command-line entry point: build the dataset index, then split it."""
    cli_args = get_args()
    # Order matters: split_dataset reads the file get_dataset writes.
    for step in (get_dataset, split_dataset):
        step(cli_args)


if __name__ == "__main__":
    main()