HoneyTian's picture
update
6e26705
raw
history blame
4.84 kB
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import argparse
from glob import glob
import json
import os
from pathlib import Path
import random
import sys
# Absolute path of the directory containing this script.
pwd = os.path.abspath(os.path.dirname(__file__))
# Append the directory two levels above this script to the import path.
# NOTE(review): presumably the project root so project-local modules resolve;
# nothing in the visible code imports from it — confirm it is still needed.
sys.path.append(os.path.join(pwd, "../../"))
import pandas as pd
from scipy.io import wavfile
from tqdm import tqdm
def get_args():
    """Parse command-line options for building and splitting the dataset.

    Returns:
        argparse.Namespace with ``file_dir``, ``task``, ``filename_patterns``,
        ``train_dataset``, ``valid_dataset`` and ``label_plan`` (all strings).
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--file_dir", default="./", type=str,
        help="directory where dataset.xlsx is written and read back",
    )
    parser.add_argument("--task", default="default", type=str)
    # Required: the dataset builder splits this string into glob patterns and
    # cannot do anything useful (it used to crash on None) without it.
    parser.add_argument(
        "--filename_patterns", type=str, required=True,
        help="whitespace-separated glob patterns of wav files to include",
    )
    parser.add_argument(
        "--train_dataset", default="train.xlsx", type=str,
        help="output spreadsheet for rows flagged TRAIN",
    )
    parser.add_argument(
        "--valid_dataset", default="valid.xlsx", type=str,
        help="output spreadsheet for all other rows",
    )
    # Only these label plans are defined; reject anything else at parse time
    # instead of failing later after directories have been created.
    parser.add_argument(
        "--label_plan", default="4", type=str, choices=["2", "3", "4", "8"],
        help="which folder-to-label mapping to use",
    )
    args = parser.parse_args()
    return args
# Folder-name -> label tables, keyed by the ``--label_plan`` value. Each plan's
# key equals the number of distinct labels it produces. A folder missing from
# the selected table (e.g. "music" in plans "3" and "4") is excluded entirely.
_LABEL_MAPS = {
    "2": {
        "bell": "non_voice",
        "white_noise": "non_voice",
        "low_white_noise": "non_voice",
        "high_white_noise": "non_voice",
        "music": "non_voice",
        "mute": "non_voice",
        "noise": "non_voice",
        "noise_mute": "non_voice",
        "voice": "voice",
        "voicemail": "voice",
    },
    "3": {
        "bell": "voicemail",
        "white_noise": "mute",
        "low_white_noise": "mute",
        "high_white_noise": "mute",
        "mute": "mute",
        "noise": "voice_or_noise",
        "noise_mute": "voice_or_noise",
        "voice": "voice_or_noise",
        "voicemail": "voicemail",
    },
    "4": {
        "bell": "voicemail",
        "white_noise": "mute",
        "low_white_noise": "mute",
        "high_white_noise": "mute",
        "mute": "mute",
        "noise": "noise",
        "noise_mute": "noise",
        "voice": "voice",
        "voicemail": "voicemail",
    },
    "8": {
        "bell": "bell",
        "white_noise": "white_noise",
        "low_white_noise": "white_noise",
        "high_white_noise": "white_noise",
        "music": "music",
        "mute": "mute",
        "noise": "noise",
        "noise_mute": "noise_mute",
        "voice": "voice",
        "voicemail": "voicemail",
    },
}

# Clips shorter than this many seconds are skipped.
_MIN_DURATION_SECONDS = 2
# Fraction of rows (in expectation) flagged "TRAIN"; the rest are "TEST".
_TRAIN_RATIO = 0.8


def get_dataset(args):
    """Scan wav files, label them and write ``<args.file_dir>/dataset.xlsx``.

    Every file matched by the whitespace-separated glob patterns in
    ``args.filename_patterns`` is labelled from its parent directory name using
    the table selected by ``args.label_plan``. Files whose parent directory is
    not in that table, or that are shorter than ``_MIN_DURATION_SECONDS``, are
    skipped. The path component three levels above the file (``parts[-4]``)
    is stored in the ``category`` column. Each kept row gets a random sort key
    (``random1``) and a random ``flag`` of "TRAIN" (probability
    ``_TRAIN_RATIO``) or "TEST". Rows are written in shuffled order and a
    per-label file count is printed.

    Args:
        args: namespace with ``file_dir``, ``filename_patterns`` and
            ``label_plan`` string attributes.

    Raises:
        AssertionError: ``args.label_plan`` is not a defined plan.
        ValueError: no glob patterns were given, or no file survived filtering.
    """
    # Validate everything before touching the filesystem.
    try:
        label_map = _LABEL_MAPS[args.label_plan]
    except KeyError:
        raise AssertionError(
            f"unsupported label_plan {args.label_plan!r}; "
            f"expected one of {sorted(_LABEL_MAPS)}"
        ) from None

    # split() without an argument ignores repeated/leading/trailing whitespace,
    # so no empty pattern (which would silently match nothing) is produced.
    filename_patterns = (args.filename_patterns or "").split()
    if not filename_patterns:
        raise ValueError("filename_patterns must contain at least one glob pattern")
    print(filename_patterns)

    file_dir = Path(args.file_dir)
    file_dir.mkdir(parents=True, exist_ok=True)

    result = []
    for filename_pattern in filename_patterns:
        for filename in tqdm(glob(filename_pattern)):
            filename = Path(filename)
            folder = filename.parts[-2]
            # Check the label before decoding so unused audio is never read.
            if folder not in label_map:
                continue
            sample_rate, signal = wavfile.read(filename.as_posix())
            # scipy returns samples along axis 0, so len() is the frame count.
            if len(signal) < sample_rate * _MIN_DURATION_SECONDS:
                continue
            country = filename.parts[-4]
            random1 = random.random()  # sort key used below to shuffle rows
            random2 = random.random()  # decides the TRAIN/TEST flag
            result.append({
                "filename": filename,
                "folder": folder,
                "category": country,
                "labels": label_map[folder],
                "random1": random1,
                "random2": random2,
                "flag": "TRAIN" if random2 < _TRAIN_RATIO else "TEST",
            })

    if not result:
        raise ValueError(
            f"no usable wav files matched {filename_patterns} "
            f"with label_plan {args.label_plan!r}"
        )

    df = pd.DataFrame(result)
    # Number of files per label.
    pivot_table = pd.pivot_table(df, index=["labels"], values=["filename"], aggfunc="count")
    print(pivot_table)
    # Sorting by a uniform random key shuffles the rows.
    df = df.sort_values(by=["random1"], ascending=False)
    df.to_excel(file_dir / "dataset.xlsx", index=False)
    return
def split_dataset(args):
    """Split ``<args.file_dir>/dataset.xlsx`` into training and validation sheets.

    Rows whose ``flag`` column equals "TRAIN" are written to
    ``args.train_dataset``; every other row (including a missing flag) goes to
    ``args.valid_dataset``. Row order and columns are preserved, and both
    outputs keep the header even when one side ends up empty.

    Note: the output paths are used exactly as given (relative to the current
    working directory when relative); they are not joined with ``file_dir``.
    """
    df = pd.read_excel(Path(args.file_dir) / "dataset.xlsx")

    # A boolean mask replaces the iterrows loop. The loop rebuilt frames from
    # per-row Series, which upcast dtypes and lost all columns for an empty side.
    is_train = df["flag"] == "TRAIN"
    df[is_train].to_excel(args.train_dataset, index=False)
    df[~is_train].to_excel(args.valid_dataset, index=False)
    return
def main():
    """Build ``dataset.xlsx`` from the matched wav files, then split it into train/valid sheets."""
    args = get_args()
    for step in (get_dataset, split_dataset):
        step(args)
    return


if __name__ == "__main__":
    # Run the pipeline only when executed as a script, not when imported.
    main()