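"""Prepare a CSV/WAV dataset for training.

Reads audio-text pairs from a pipe-separated metadata.csv, filters clips by
duration (measured with ffprobe), converts transcripts to pinyin, and writes
raw.arrow, duration.json, and new_vocab.txt to the output directory.
"""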
import os
import sys
import csv
import json
import argparse
import subprocess
from pathlib import Path
from multiprocessing import Pool

from datasets.arrow_writer import ArrowWriter
from tqdm import tqdm

# Make the local f5_tts package importable before importing from it
sys.path.append(os.getcwd())
from f5_tts.model.utils import convert_char_to_pinyin

# Increase the CSV field size limit so very long transcript fields do not error
csv.field_size_limit(sys.maxsize)


# def get_audio_duration(audio_path):
#     """Use SoX for instant audio duration retrieval"""
#     result = os.popen(f"soxi -D {audio_path}").read().strip()
#     return float(result) if result else 0
def get_audio_duration(audio_path):
    """Use ffprobe for accurate duration retrieval without header issues."""
    try:
        result = subprocess.run(
            [
                "ffprobe", "-v", "error",
                "-show_entries", "format=duration",
                "-of", "default=noprint_wrappers=1:nokey=1",
                audio_path,
            ],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        return float(result.stdout.strip()) if result.stdout.strip() else 0
    except Exception as e:
        print(f"Error processing {audio_path}: {e}")
        return 0
def read_audio_text_pairs(csv_file_path):
    """Use AWK to quickly extract (audio path, text) pairs from the pipe-separated CSV."""
    awk_cmd = f"awk -F '|' 'NR > 1 {{ print $1, $2 }}' {csv_file_path}"
    output = os.popen(awk_cmd).read().strip().split("\n")
    parent = Path(csv_file_path).parent
    return [
        (str(parent / line.split(" ")[0]), " ".join(line.split(" ")[1:]))
        for line in output
        if len(line.split(" ")) >= 2
    ]
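
# Expected metadata.csv layout, inferred from the awk command above (pipe
# separator, header row skipped via NR > 1, audio paths relative to the CSV's
# directory); the filenames below are illustrative only:
#
#   audio_file|text
#   wavs/utt_0001.wav|first transcript
#   wavs/utt_0002.wav|another transcript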
def process_audio(audio_path_text):
    """Process one audio file: check existence, compute duration, convert text to pinyin."""
    audio_path, text = audio_path_text
    if not Path(audio_path).exists():
        return None
    duration = get_audio_duration(audio_path)
    # Skip clips that are too short or too long to be useful for training
    if duration < 0.1 or duration > 30:
        return None
    text = convert_char_to_pinyin([text], polyphone=True)[0]
    return {"audio_path": audio_path, "text": text, "duration": duration}, duration
def prepare_csv_wavs_dir(input_dir, num_processes=32):
    """Parallelized processing of audio-text pairs using multiprocessing."""
    input_dir = Path(input_dir)
    metadata_path = input_dir / "metadata.csv"
    audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())
    with Pool(num_processes) as pool:
        results = list(
            tqdm(
                pool.imap(process_audio, audio_path_text_pairs),
                total=len(audio_path_text_pairs),
                desc="Processing audio files",
            )
        )
    sub_result, durations, vocab_set = [], [], set()
    for result in results:
        if result:
            sub_result.append(result[0])
            durations.append(result[1])
            vocab_set.update(list(result[0]["text"]))
    return sub_result, durations, vocab_set
def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set):
    """Write the processed dataset to disk: Arrow samples, durations, and vocab."""
    out_dir = Path(out_dir)
    out_dir.mkdir(exist_ok=True, parents=True)
    print(f"\nSaving to {out_dir} ...")

    raw_arrow_path = out_dir / "raw.arrow"
    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
        for line in tqdm(result, desc="Writing to raw.arrow"):
            writer.write(line)  # Stream each sample directly into the Arrow file

    dur_json_path = out_dir / "duration.json"
    with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    voca_out_path = out_dir / "new_vocab.txt"
    with open(voca_out_path.as_posix(), "w") as f:
        f.writelines(f"{vocab}\n" for vocab in sorted(text_vocab_set))

    dataset_name = out_dir.stem
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, total {sum(duration_list) / 3600:.2f} hours")
def prepare_and_save_set(inp_dir, out_dir):
    """Run the full dataset preparation pipeline."""
    sub_result, durations, vocab_set = prepare_csv_wavs_dir(inp_dir)
    save_prepped_dataset(out_dir, sub_result, durations, vocab_set)


def cli():
    """Command-line interface for the script."""
    parser = argparse.ArgumentParser(description="Prepare and save dataset.")
    parser.add_argument("inp_dir", type=str, help="Input directory containing the data.")
    parser.add_argument("out_dir", type=str, help="Output directory to save the prepared data.")
    args = parser.parse_args()
    prepare_and_save_set(args.inp_dir, args.out_dir)


if __name__ == "__main__":
    cli()
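
# Example invocation (script filename assumed; inp_dir must contain
# metadata.csv and the wav files it references):
#
#   python prepare_csv_wavs.py /data/my_dataset /data/my_dataset_prepped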