File size: 6,408 Bytes
f25cff8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
import os
import sys

sys.path.append(os.getcwd())

import argparse
import csv
import json
import shutil
from importlib.resources import files
from pathlib import Path
from concurrent.futures import ThreadPoolExecutor, as_completed

import torchaudio
from tqdm import tqdm
from datasets.arrow_writer import ArrowWriter

from f5_tts.model.utils import (
    convert_char_to_pinyin,
)


# Raise the csv module's per-field size cap so very long fields can be parsed.
# The limit is stored in a C long; where that type is 32-bit, sys.maxsize does
# not fit and OverflowError is raised, so fall back to the largest 32-bit value.
try:
    csv.field_size_limit(sys.maxsize)
except OverflowError:
    csv.field_size_limit(2**31 - 1)

# PRETRAINED_VOCAB_PATH = files("f5_tts").joinpath("../../data/Emilia_ZH_EN_pinyin/vocab.txt")
# NOTE(review): machine-specific absolute path to the pretrained vocab file;
# adjust for the local checkout before running in fine-tune mode.
PRETRAINED_VOCAB_PATH = Path("/home/tts/ttsteam/repos/F5-TTS/ckpts/vocab.txt")


def is_csv_wavs_format(input_dataset_dir):
    """Report whether a directory follows the csv_wavs layout.

    The layout is a regular file named ``metadata.csv`` and a sub-directory
    named ``wavs``, both directly under ``input_dataset_dir``.

    Returns:
        True when both are present, False otherwise.
    """
    root = Path(input_dataset_dir)

    # is_file() / is_dir() are already False for paths that do not exist.
    if not (root / "metadata.csv").is_file():
        return False
    return (root / "wavs").is_dir()


def prepare_csv_wavs_dir(input_dir, num_threads=16):  # Added num_threads parameter
    """Build dataset entries from ``<input_dir>/metadata.csv``.

    Every (audio path, text) pair read from the metadata file is processed on
    a thread pool: the audio duration is obtained with ``get_audio_duration``
    and the text is passed through ``convert_char_to_pinyin``. Pairs whose
    audio file does not exist are reported on stdout and skipped.

    Args:
        input_dir: Directory that contains ``metadata.csv``.
        num_threads: Maximum number of worker threads.

    Returns:
        A tuple ``(sub_result, durations, vocab_set)``:
        ``sub_result`` is a list of dicts with keys ``audio_path``, ``text``
        and ``duration``; ``durations`` is the list of durations aligned with
        ``sub_result``; ``vocab_set`` is the set of all items found in the
        converted texts. Entries follow the row order of the metadata file,
        so repeated runs over the same input give the same output order.
    """
    print("Inside prepare csv wavs dir!")
    # assert is_csv_wavs_format(input_dir), f"not csv_wavs format: {input_dir}"
    metadata_path = Path(input_dir) / "metadata.csv"
    audio_path_text_pairs = read_audio_text_pairs(metadata_path.as_posix())

    sub_result, durations = [], []
    vocab_set = set()
    polyphone = True

    def process_audio(audio_path_text):
        # Returns the entry dict for one pair, or None when the audio is missing.
        audio_path, text = audio_path_text
        if not Path(audio_path).exists():
            print(f"audio {audio_path} not found, skipping")
            return None
        audio_duration = get_audio_duration(audio_path)
        text = convert_char_to_pinyin([text], polyphone=polyphone)[0]
        return {"audio_path": audio_path, "text": text, "duration": audio_duration}

    with ThreadPoolExecutor(max_workers=num_threads) as executor:  # Set max_workers
        # executor.map yields results in submission order, unlike as_completed,
        # so the output no longer depends on which thread finishes first.
        entries = executor.map(process_audio, audio_path_text_pairs)

        # Use tqdm to track progress
        for entry in tqdm(entries, total=len(audio_path_text_pairs), desc="Processing audio files"):
            if entry is None:
                continue
            sub_result.append(entry)
            durations.append(entry["duration"])
            vocab_set.update(entry["text"])

    return sub_result, durations, vocab_set


def get_audio_duration(audio_path):
    """Return the length of an audio file as frame count divided by sample rate.

    The file is fully loaded with ``torchaudio.load``; the frame count is
    taken from dimension 1 of the returned tensor (presumably laid out as
    (channels, frames), which would make the result seconds).
    """
    waveform, rate = torchaudio.load(audio_path)
    frame_count = waveform.shape[1]
    return frame_count / rate


def read_audio_text_pairs(csv_file_path):
    """Read (audio path, text) pairs from a ``|``-delimited metadata file.

    The first row is treated as a header and skipped. For each following row
    with at least two columns, the first column is taken as the audio path
    and the second as the text, both stripped of surrounding whitespace.
    Extra columns are ignored and rows with fewer than two columns are
    skipped. Audio paths are used as written in the file (they are not
    resolved against the CSV's directory) and returned in POSIX form.

    Args:
        csv_file_path: Path of the metadata file (UTF-8, optional BOM).

    Returns:
        A list of ``(audio_path, text)`` string tuples in file order; an
        empty list when the file is empty or has only a header.
    """
    with open(csv_file_path, mode="r", newline="", encoding="utf-8-sig") as csvfile:
        reader = csv.reader(csvfile, delimiter="|")
        # Skip the header row; the default keeps an empty file from raising
        # StopIteration.
        next(reader, None)
        return [
            (Path(row[0].strip()).as_posix(), row[1].strip())
            for row in reader
            if len(row) >= 2
        ]


def save_prepped_dataset(out_dir, result, duration_list, text_vocab_set, is_finetune):
    """Write the prepared dataset files into ``out_dir`` and print a summary.

    Files written (``out_dir`` is created if needed):
      * ``raw.arrow``     - one record per item of ``result``.
      * ``duration.json`` - ``{"duration": duration_list}``.
      * ``vocab.txt``     - sorted ``text_vocab_set``, one entry per line.
      * ``new_vocab.txt`` - a copy of ``PRETRAINED_VOCAB_PATH`` when
        ``is_finetune`` is true, otherwise the same content as ``vocab.txt``.

    Args:
        out_dir: Output directory.
        result: Iterable of records to write to ``raw.arrow``; its length is
            reported as the sample count.
        duration_list: Durations to store; their sum divided by 3600 is
            reported as hours.
        text_vocab_set: Vocabulary entries (strings).
        is_finetune: Selects the content of ``new_vocab.txt`` (see above).
    """
    out_dir = Path(out_dir)
    # save preprocessed dataset to disk
    out_dir.mkdir(exist_ok=True, parents=True)
    print(f"\nSaving to {out_dir} ...")

    # Records are streamed through ArrowWriter; per the original author's
    # note, building a Dataset from an in-memory dict ran out of memory.
    raw_arrow_path = out_dir / "raw.arrow"
    with ArrowWriter(path=raw_arrow_path.as_posix(), writer_batch_size=1) as writer:
        for line in tqdm(result, desc="Writing to raw.arrow ..."):
            writer.write(line)

    # dup a json separately saving duration in case for DynamicBatchSampler ease
    dur_json_path = out_dir / "duration.json"
    with open(dur_json_path.as_posix(), "w", encoding="utf-8") as f:
        json.dump({"duration": duration_list}, f, ensure_ascii=False)

    # vocab map, i.e. tokenizer
    # add alphabets and symbols (optional, if plan to ft on de/fr etc.)
    # if tokenizer == "pinyin":
    #     text_vocab_set.update([chr(i) for i in range(32, 127)] + [chr(i) for i in range(192, 256)])
    vocab_lines = [vocab + "\n" for vocab in sorted(text_vocab_set)]

    # Explicit UTF-8: vocab entries may be non-ASCII, and the locale default
    # encoding is not guaranteed to represent them.
    with open((out_dir / "vocab.txt").as_posix(), "w", encoding="utf-8") as f:
        f.writelines(vocab_lines)

    # NOTE(review): in fine-tune mode the pretrained vocab lands in
    # new_vocab.txt while vocab.txt holds the dataset vocab. This mirrors the
    # existing behavior; confirm it is the intended mapping.
    new_vocab_path = out_dir / "new_vocab.txt"
    if is_finetune:
        shutil.copy2(PRETRAINED_VOCAB_PATH.as_posix(), new_vocab_path)
    else:
        with open(new_vocab_path.as_posix(), "w", encoding="utf-8") as f:
            f.writelines(vocab_lines)

    dataset_name = out_dir.stem
    print(f"\nFor {dataset_name}, sample count: {len(result)}")
    print(f"For {dataset_name}, vocab size is: {len(text_vocab_set)}")
    print(f"For {dataset_name}, total {sum(duration_list)/3600:.2f} hours")


def prepare_and_save_set(inp_dir, out_dir, is_finetune: bool = True):
    """Prepare the dataset found in ``inp_dir`` and save it under ``out_dir``.

    In fine-tune mode the pretrained vocab file at ``PRETRAINED_VOCAB_PATH``
    must exist; this is asserted before any processing starts.

    Args:
        inp_dir: Input directory handed to ``prepare_csv_wavs_dir``.
        out_dir: Output directory handed to ``save_prepped_dataset``.
        is_finetune: True for fine-tuning, False for a new pretrain.
    """
    if is_finetune:
        print("Inside finetuning ...")
        assert PRETRAINED_VOCAB_PATH.exists(), f"pretrained vocab.txt not found: {PRETRAINED_VOCAB_PATH}"

    prepared = prepare_csv_wavs_dir(inp_dir)
    entries, entry_durations, vocabulary = prepared
    save_prepped_dataset(out_dir, entries, entry_durations, vocabulary, is_finetune)


def cli():
    """Command-line entry point: parse the arguments and run the preparation.

    Usage:
        finetune: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin
        pretrain: python scripts/prepare_csv_wavs.py /path/to/input_dir /path/to/output_dir_pinyin --pretrain
    """
    parser = argparse.ArgumentParser(description="Prepare and save dataset.")

    # Both positional arguments are plain strings and differ only in name/help.
    positionals = (
        ("inp_dir", "Input directory containing the data."),
        ("out_dir", "Output directory to save the prepared data."),
    )
    for arg_name, help_text in positionals:
        parser.add_argument(arg_name, type=str, help=help_text)
    parser.add_argument("--pretrain", action="store_true", help="Enable for new pretrain, otherwise is a fine-tune")

    options = parser.parse_args()

    # Fine-tuning is the default; --pretrain switches it off.
    fine_tuning = not options.pretrain
    prepare_and_save_set(options.inp_dir, options.out_dir, is_finetune=fine_tuning)


# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    cli()