#!/usr/bin/python3
# -*- coding: utf-8 -*-
"""Index labelled wav files and split the index into train / valid Excel sheets.

Every wav file matched by ``--filename_patterns`` is read from a path laid out
as ``.../<category>/<any>/<folder>/<name>.wav``: the parent directory name
(``folder``) is mapped to a training label through the selected
``--label_plan``, and the directory three levels up is recorded as
``category``.
"""
import argparse
from glob import glob
import json
import os
from pathlib import Path
import random
import sys

pwd = os.path.abspath(os.path.dirname(__file__))
sys.path.append(os.path.join(pwd, "../../"))

import pandas as pd
from scipy.io import wavfile
from tqdm import tqdm


# Folder name -> training label, one mapping per ``--label_plan``.
# Folders absent from a plan (e.g. "music" in plans "3" and "4") are skipped.
LABEL_MAPS = {
    "2": {
        "bell": "non_voice",
        "white_noise": "non_voice",
        "low_white_noise": "non_voice",
        "high_white_noise": "non_voice",
        "music": "non_voice",
        "mute": "non_voice",
        "noise": "non_voice",
        "noise_mute": "non_voice",
        "voice": "voice",
        "voicemail": "voice",
    },
    "3": {
        "bell": "voicemail",
        "white_noise": "mute",
        "low_white_noise": "mute",
        "high_white_noise": "mute",
        "mute": "mute",
        "noise": "voice_or_noise",
        "noise_mute": "voice_or_noise",
        "voice": "voice_or_noise",
        "voicemail": "voicemail",
    },
    "4": {
        "bell": "voicemail",
        "white_noise": "mute",
        "low_white_noise": "mute",
        "high_white_noise": "mute",
        "mute": "mute",
        "noise": "noise",
        "noise_mute": "noise",
        "voice": "voice",
        "voicemail": "voicemail",
    },
    "8": {
        "bell": "bell",
        "white_noise": "white_noise",
        "low_white_noise": "white_noise",
        "high_white_noise": "white_noise",
        "music": "music",
        "mute": "mute",
        "noise": "noise",
        "noise_mute": "noise_mute",
        "voice": "voice",
        "voicemail": "voicemail",
    },
}

# Column order of the generated dataset sheet.
DATASET_COLUMNS = [
    "filename", "folder", "category", "labels", "random1", "random2", "flag",
]


def get_args():
    """Parse the command-line options of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--file_dir", default="./", type=str)
    parser.add_argument("--task", default="default", type=str)
    # Space-separated glob patterns; required because get_dataset splits it.
    parser.add_argument("--filename_patterns", required=True, type=str)
    parser.add_argument("--train_dataset", default="train.xlsx", type=str)
    parser.add_argument("--valid_dataset", default="valid.xlsx", type=str)
    parser.add_argument("--label_plan", default="4", choices=sorted(LABEL_MAPS), type=str)
    parser.add_argument("--min_duration", default=2.0, type=float,
                        help="skip wav files shorter than this many seconds")
    parser.add_argument("--train_ratio", default=0.8, type=float,
                        help="approximate fraction of samples flagged TRAIN")
    parser.add_argument("--seed", default=None, type=int,
                        help="random seed for a reproducible split")
    return parser.parse_args()


def get_dataset(args):
    """Scan wav files, label them and write ``<file_dir>/dataset.xlsx``.

    Files whose parent folder is not in the selected label map, or that are
    shorter than ``min_duration`` seconds, are skipped. Each kept file gets
    two uniform random numbers: ``random1`` shuffles the output rows and
    ``random2`` assigns the TRAIN/TEST flag.

    :param args: parsed options (see ``get_args``); ``min_duration``,
        ``train_ratio`` and ``seed`` are optional and default to 2.0, 0.8
        and None.
    :raises ValueError: if ``args.label_plan`` is not a key of LABEL_MAPS.
    """
    filename_patterns = args.filename_patterns.split()
    print(filename_patterns)

    file_dir = Path(args.file_dir)
    file_dir.mkdir(parents=True, exist_ok=True)

    label_map = LABEL_MAPS.get(args.label_plan)
    if label_map is None:
        raise ValueError(
            f"unsupported label_plan: {args.label_plan!r}; "
            f"expected one of {sorted(LABEL_MAPS)}"
        )

    min_duration = getattr(args, "min_duration", 2.0)
    train_ratio = getattr(args, "train_ratio", 0.8)
    seed = getattr(args, "seed", None)
    if seed is not None:
        random.seed(seed)

    result = list()
    for filename_pattern in filename_patterns:
        # Sorted so a fixed seed produces the same split on every run.
        filename_list = sorted(glob(filename_pattern))
        for filename in tqdm(filename_list):
            filename = Path(filename)
            folder = filename.parts[-2]
            # Filter on the label first so unmapped files are never decoded.
            if folder not in label_map:
                continue

            sample_rate, signal = wavfile.read(filename.as_posix())
            if len(signal) < sample_rate * min_duration:
                continue

            random1 = random.random()
            random2 = random.random()
            result.append({
                "filename": str(filename),
                "folder": folder,
                "category": filename.parts[-4],
                "labels": label_map[folder],
                "random1": random1,
                "random2": random2,
                "flag": "TRAIN" if random2 < train_ratio else "TEST",
            })

    # Explicit columns keep the sheet's header even when nothing matched.
    df = pd.DataFrame(result, columns=DATASET_COLUMNS)
    if df.empty:
        print("no wav files matched; writing an empty dataset")
    else:
        pivot_table = pd.pivot_table(df, index=["labels"], values=["filename"], aggfunc="count")
        print(pivot_table)

    df = df.sort_values(by=["random1"], ascending=False)
    df.to_excel(file_dir / "dataset.xlsx", index=False)


def split_dataset(args):
    """Split ``<file_dir>/dataset.xlsx`` into the train and valid sheets.

    Rows flagged "TRAIN" go to ``args.train_dataset``; every other row
    (including a missing flag) goes to ``args.valid_dataset``.
    """
    file_dir = Path(args.file_dir)
    df = pd.read_excel(file_dir / "dataset.xlsx")

    is_train = df["flag"] == "TRAIN"
    df[is_train].to_excel(args.train_dataset, index=False)
    df[~is_train].to_excel(args.valid_dataset, index=False)


def main():
    """Entry point: build the dataset index, then split it."""
    args = get_args()
    get_dataset(args)
    split_dataset(args)


if __name__ == "__main__":
    main()