import itertools |
import os |
import re |
from collections import defaultdict |
from functools import partial |
from multiprocessing import Pool |
from pathlib import Path |
import click |
import numpy as np |
from loguru import logger |
from tqdm import tqdm |
from fish_speech.datasets.protos.text_data_pb2 import Semantics, Sentence, TextData |
from fish_speech.datasets.protos.text_data_stream import pack_pb_stream |
from fish_speech.utils.file import load_filelist |
# Pin the MKL and OpenMP thread-count environment variables to "1".
# NOTE(review): presumably this keeps each pool worker single-threaded so the
# workers do not oversubscribe the CPU -- confirm. These are set after numpy
# has already been imported above; verify the limits still take effect.
os.environ.update(
    {
        "MKL_NUM_THREADS": "1",
        "OMP_NUM_THREADS": "1",
    }
)
def task_generator_folder(root: Path, text_extension: str):
    """Yield one task per directory under ``root`` that contains ``.npy`` files.

    Every ``*.npy`` file found recursively below ``root`` is paired with the
    text file(s) obtained by swapping its suffix for ``text_extension``.
    Files are grouped by the string form of their parent directory, and the
    parent directory's name is used as the group (speaker) name.

    Args:
        root: Directory to scan recursively.
        text_extension: Either a single suffix string (e.g. ``".txt"``) or an
            iterable of suffixes; one text is read per suffix.

    Yields:
        Tuples ``(speaker, [(npy_path, texts), ...], "folder")``.

    A file for which any text cannot be read is logged and left out.
    """
    npy_paths = sorted(tqdm(Path(root).rglob("*.npy"), desc=f"Loading {root}"))

    grouped_files = defaultdict(list)
    for file in tqdm(npy_paths, desc=f"Grouping {root}"):
        group_key = str(file.parent)
        group_name = file.parent.name

        try:
            # A bare string means a single extension; anything else is
            # treated as an iterable of extensions.
            if isinstance(text_extension, str):
                suffixes = [text_extension]
            else:
                suffixes = text_extension
            transcripts = [
                file.with_suffix(suffix).read_text(encoding="utf-8")
                for suffix in suffixes
            ]
        except Exception as e:
            logger.error(f"Failed to read text {file}: {e}")
            continue

        grouped_files[group_key].append((group_name, file, transcripts))

    logger.info(
        f"Found {len(grouped_files)} groups in {root}, {list(grouped_files.keys())[:5]}..."
    )

    for members in grouped_files.values():
        pairs = [(path, transcripts) for _, path, transcripts in members]
        # All members of a group share a parent directory, so the first
        # entry's name stands for the whole group.
        yield members[0][0], pairs, "folder"
def task_generator_filelist(filelist):
    """Yield one task per speaker listed in ``filelist``.

    Each entry produced by ``load_filelist`` is unpacked as a 4-tuple of
    ``(filename, speaker, <ignored>, text)``; entries are grouped by speaker.

    Yields:
        Tuples ``(speaker, [(Path(filename), [text]), ...], "filelist")``.
    """
    grouped_files = defaultdict(list)
    for entry in load_filelist(filelist):
        filename, speaker, _, text = entry
        # The text is wrapped in a list so the shape matches the folder tasks.
        grouped_files[speaker].append((Path(filename), [text]))

    logger.info(f"Found {len(grouped_files)} groups in {filelist}")

    yield from (
        (speaker, values, "filelist") for speaker, values in grouped_files.items()
    )
def run_task(task):
    """Convert one ``(name, subset, source)`` task into a packed ``TextData``.

    For every ``(file, texts)`` pair in ``subset`` the matching ``.npy`` file
    is loaded and combined with the cleaned texts into a ``Sentence``.
    Pairs whose ``.npy`` file is missing or cannot be loaded are logged and
    skipped.

    Returns:
        The result of ``pack_pb_stream`` applied to the assembled ``TextData``.
    """
    name, subset, source = task

    # Applied in order: drop ``{...}`` spans, drop ``<...>`` spans, then
    # collapse every whitespace run to a single space.
    cleanup_patterns = (r"\{.*?\}", r"<.*?>", r"\s+")

    sentences = []
    for file, texts in subset:
        np_file = file.with_suffix(".npy")
        if not np_file.exists():
            logger.warning(f"Can't find {np_file}")
            continue

        cleaned_texts = []
        for cleaned in texts:
            for pattern in cleanup_patterns:
                cleaned = re.sub(pattern, " ", cleaned)
            cleaned_texts.append(cleaned)

        try:
            semantics = np.load(np_file)
        except Exception as e:
            logger.error(f"Failed to parse {file}: {e}")
            continue

        if isinstance(semantics, np.ndarray):
            semantics = semantics.tolist()

        # One ``Semantics`` message per top-level element of the loaded data
        # (presumably one per codebook row -- TODO confirm the array layout).
        semantic_rows = [Semantics(values=row) for row in semantics]
        sentences.append(Sentence(texts=cleaned_texts, semantics=semantic_rows))

    text_data = TextData(source=source, name=name, sentences=sentences)
    return pack_pb_stream(text_data)
@click.command()
@click.option(
    "--input",
    type=click.Path(path_type=Path),
    required=True,
    help="A folder containing the dataset or a filelist",
    multiple=True,
)
@click.option(
    "--output", type=click.Path(path_type=Path), default="data/quantized-dataset-ft"
)
@click.option("--num-workers", type=int, default=16)
@click.option("--text-extension", type=str, default=[".txt"], multiple=True)
@click.option(
    "--shard-size", type=int, default=10, help="The maximum size of each shard in mb"
)
def main(input, output, num_workers, text_extension, shard_size):
    """Pack semantic tokens and their texts into sharded .protos files."""
    # Each --input is either a directory (scanned for .npy files) or a
    # filelist; the generators are lazy, so nothing is read yet.
    generator_fns = []
    for f in input:
        assert f.exists(), f"{f} not found"

        if f.is_dir():
            generator_fn = task_generator_folder(f, text_extension)
        else:
            generator_fn = task_generator_filelist(f)

        generator_fns.append(generator_fn)

    generator_fn = itertools.chain(*generator_fns)
    output.mkdir(parents=True, exist_ok=True)

    shard_limit = shard_size * 1024 * 1024  # --shard-size is given in MiB
    dataset_fp = None  # currently open shard, opened lazily on first result
    tar_idx = 0  # number of shards completed so far; also the next shard index
    written_size = 0

    try:
        with Pool(num_workers) as p:
            for result in tqdm(p.imap_unordered(run_task, generator_fn)):
                if dataset_fp is None:
                    dataset_fp = open(Path(output) / f"{tar_idx:08d}.protos", "wb")

                dataset_fp.write(result)
                written_size += len(result)

                # Roll over once the shard exceeds the limit. A shard may
                # overshoot by up to one task because the check runs after
                # the write.
                if written_size > shard_limit:
                    dataset_fp.close()
                    dataset_fp = None
                    written_size = 0
                    tar_idx += 1
                    logger.info(f"Finished writing {tar_idx} shards to {output}")
    finally:
        # Close the partially filled shard even if a worker or a write
        # raised, so the file handle is never leaked.
        if dataset_fp is not None:
            dataset_fp.close()
            tar_idx += 1

    # tar_idx counts only shards actually created, so this is correct both
    # when the last shard closed exactly at the limit and when there were
    # no results at all.
    logger.info(f"Finished writing {tar_idx} shards to {output}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()