import os
import numpy as np
import torch
import librosa
import logging
import shutil
from pkg_resources import resource_filename
from accelerate import Accelerator
from datasets import load_dataset, DatasetDict, Dataset, Audio

from .preprocess import Preprocessor, crop_feats_length
from .hubert import HubertFeatureExtractor, HubertModel, load_hubert
from .f0 import F0Extractor, RMVPE, load_rmvpe
from .constants import *

logger = logging.getLogger(__name__)


def extract_hubert_features(
    rows,
    hfe: HubertFeatureExtractor,
    hubert: str | HubertModel | None,
    device: torch.device,
):
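    """Extract HuBERT features from each row's 16 kHz "wav_16k" audio, loading the HuBERT model lazily on first use."""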
    if not hfe.is_loaded():
        model = load_hubert(hubert, device)
        hfe.load(model)
    feats = []
    for row in rows["wav_16k"]:
        feat = hfe.extract_feature_from(row["array"].astype("float32"))
        feats.append(feat)
    return {"hubert_feats": feats}


def extract_f0_features(
    rows, f0e: F0Extractor, rmvpe: str | RMVPE | None, device: torch.device
):
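    """Extract pitch features ("f0" and "f0nsf") from each row's 16 kHz "wav_16k" audio, loading the RMVPE model lazily on first use."""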
    if not f0e.is_loaded():
        model = load_rmvpe(rmvpe, device)
        f0e.load(model)
    f0s = []
    f0nsfs = []
    for row in rows["wav_16k"]:
        f0nsf, f0 = f0e.extract_f0_from(row["array"].astype("float32"))
        f0s.append(f0)
        f0nsfs.append(f0nsf)
    return {"f0": f0s, "f0nsf": f0nsfs}


def feature_postprocess(rows):
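    """Repeat each HuBERT frame twice along the time axis, cap the result at 900 frames, and crop the f0 features to match."""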
    phones = rows["hubert_feats"]
    for i, phone in enumerate(phones):
        phone = np.repeat(phone, 2, axis=0)
        n_num = min(phone.shape[0], 900)
        phone = phone[:n_num, :]
        phones[i] = phone
        if "f0" in rows:
            pitch = rows["f0"][i]
            pitch = pitch[:n_num]
            pitch = np.array(pitch, dtype=np.float32)
            rows["f0"][i] = pitch
        if "f0nsf" in rows:
            pitchf = rows["f0nsf"][i]
            pitchf = pitchf[:n_num]
            rows["f0nsf"][i] = pitchf
    return rows


def calculate_spectrogram(
    rows, n_fft=N_FFT, hop_length=HOP_LENGTH, win_length=WIN_LENGTH
):
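    """Compute the linear magnitude spectrogram of each row's "wav_gt" audio using a Hann window and reflection padding."""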
    specs = []
    hann_window = np.hanning(win_length)
    pad_amount = int((win_length - hop_length) / 2)
    for row in rows["wav_gt"]:
        stft = librosa.stft(
            np.pad(row["array"], (pad_amount, pad_amount), mode="reflect"),
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            window=hann_window,
            center=False,
        )
        specs.append(np.abs(stft) + 1e-6)
    return {"spec": specs}


def fix_length(rows, hop_length=HOP_LENGTH):
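    """Crop each row's spectrogram, HuBERT, and f0 features to consistent lengths and trim "wav_gt" to match."""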
    for i, row in enumerate(rows["spec"]):
        spec = np.array(row)
        phone = np.array(rows["hubert_feats"][i])
        pitch = np.array(rows["f0"][i])
        pitchf = np.array(rows["f0nsf"][i])
        wav_gt = np.array(rows["wav_gt"][i]["array"])
        spec, phone, pitch, pitchf = crop_feats_length(spec, phone, pitch, pitchf)
        phone_len = phone.shape[0]
        wav_gt = wav_gt[: phone_len * hop_length]
        rows["hubert_feats"][i] = phone
        rows["f0"][i] = pitch
        rows["f0nsf"][i] = pitchf
        rows["spec"][i] = spec
        rows["wav_gt"][i]["array"] = wav_gt
    return rows


def prepare(
    dir: str | DatasetDict | Dataset,
    sr=SR_48K,
    hubert: str | HubertModel | None = None,
    rmvpe: str | RMVPE | None = None,
    batch_size=1,
    max_slice_length: float | None = 3.0,
    accelerator: Accelerator | None = None,
    include_mute=True,
    stage=3,
):
""" | |
Prepare the dataset for training or evaluation. | |
Args: | |
dir (str | DatasetDict): The directory path or DatasetDict object containing the dataset. | |
sr (int, optional): The target sampling rate. Defaults to SR_48K. | |
hubert (str | HubertModel | None, optional): The Hubert model or its name to use for feature extraction. Defaults to None. | |
rmvpe (str | RMVPE | None, optional): The RMVPE model or its name to use for feature extraction. Defaults to None. | |
batch_size (int, optional): The batch size for processing the dataset. Defaults to 1. | |
accelerator (Accelerator, optional): The accelerator object for distributed training. Defaults to None. | |
include_mute (bool, optional): Whether to include a mute audio file in the directory dataset. Defaults to True. | |
stage (int, optional): The dataset preparation level to perform. Defaults to 3. (Stage 1 and 3 are CPU intensive, Stage 2 is GPU intensive.) | |
Returns: | |
DatasetDict: The prepared dataset. | |
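
    Example (illustrative; "./my-dataset" is a placeholder audio folder):
        ds = prepare("./my-dataset", sr=SR_48K, batch_size=4)

        # stop early to inspect the CPU-side slicing/resampling output
        ds_sliced = prepare("./my-dataset", stage=1)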
""" | |
if accelerator is None: | |
accelerator = Accelerator() | |
if isinstance(dir, (DatasetDict, Dataset)): | |
ds = dir | |
else: | |
mute_source = resource_filename("zerorvc", "assets/mute/mute48k.wav") | |
mute_dest = os.path.join(dir, "mute.wav") | |
if include_mute and not os.path.exists(mute_dest): | |
logger.info(f"Copying {mute_source} to {mute_dest}") | |
shutil.copy(mute_source, mute_dest) | |
ds: DatasetDict | Dataset = load_dataset("audiofolder", data_dir=dir) | |
for key in ds: | |
ds[key] = ds[key].remove_columns( | |
[col for col in ds[key].column_names if col != "audio"] | |
) | |
ds = ds.cast_column("audio", Audio(sampling_rate=sr)) | |
if stage <= 0: | |
return ds | |
    # Stage 1, CPU intensive
    pp = Preprocessor(sr, max_slice_length) if max_slice_length is not None else None

    def preprocess(rows):
        wav_gt = []
        wav_16k = []
        for row in rows["audio"]:
            if pp is not None:
                slices = pp.preprocess_audio(row["array"])
                for slice in slices:
                    wav_gt.append({"path": "", "array": slice, "sampling_rate": sr})
                    slice16k = librosa.resample(slice, orig_sr=sr, target_sr=SR_16K)
                    wav_16k.append(
                        {"path": "", "array": slice16k, "sampling_rate": SR_16K}
                    )
            else:
                slice = row["array"]
                wav_gt.append({"path": "", "array": slice, "sampling_rate": sr})
                slice16k = librosa.resample(slice, orig_sr=sr, target_sr=SR_16K)
                wav_16k.append({"path": "", "array": slice16k, "sampling_rate": SR_16K})
        return {"wav_gt": wav_gt, "wav_16k": wav_16k}

    ds = ds.map(
        preprocess, batched=True, batch_size=batch_size, remove_columns=["audio"]
    )
    ds = ds.cast_column("wav_gt", Audio(sampling_rate=sr))
    ds = ds.cast_column("wav_16k", Audio(sampling_rate=SR_16K))

    if stage <= 1:
        return ds
    # Stage 2, GPU intensive
    hfe = HubertFeatureExtractor()
    ds = ds.map(
        extract_hubert_features,
        batched=True,
        batch_size=batch_size,
        fn_kwargs={"hfe": hfe, "hubert": hubert, "device": accelerator.device},
    )

    f0e = F0Extractor()
    ds = ds.map(
        extract_f0_features,
        batched=True,
        batch_size=batch_size,
        fn_kwargs={"f0e": f0e, "rmvpe": rmvpe, "device": accelerator.device},
    )

    if stage <= 2:
        return ds
    # Stage 3, CPU intensive
    ds = ds.map(feature_postprocess, batched=True, batch_size=batch_size)
    ds = ds.map(calculate_spectrogram, batched=True, batch_size=batch_size)
    ds = ds.map(fix_length, batched=True, batch_size=batch_size)

    return ds


def show_dataset_pitch_distribution(dataset):
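    """Plot a histogram of the dataset's "f0" values (entries equal to 1 are skipped) together with summary statistics."""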
    import matplotlib.pyplot as plt
    import seaborn as sns

    sns.set_theme()

    pitches = []
    for row in dataset["f0"]:
        pitches.extend([p for p in row if p != 1])
    pitches = np.array(pitches)

    stats = {
        "mean": np.mean(pitches),
        "std": np.std(pitches),
        "min": np.min(pitches),
        "max": np.max(pitches),
        "median": np.median(pitches),
        "q1": np.percentile(pitches, 25),
        "q3": np.percentile(pitches, 75),
    }

    plt.figure(figsize=(10, 6))
    sns.histplot(pitches, bins=100)
    plt.title(
        f"Pitch Distribution\nMean: {stats['mean']:.1f} ± {stats['std']:.1f}\n"
        f"Range: [{stats['min']:.1f}, {stats['max']:.1f}]\n"
        f"Quartiles: [{stats['q1']:.1f}, {stats['median']:.1f}, {stats['q3']:.1f}]"
    )
    plt.xlabel("Frequency (Note)")
    plt.ylabel("Count")
    plt.show()