import pandas as pd
import os
import random
import ast
import numpy as np
import torch
from einops import repeat, rearrange
import librosa
from torch.utils.data import Dataset
import torchaudio


def log_f0(f0, f0_min=librosa.note_to_hz('C2'), scales=4):
    """Map an F0 contour in Hz to a normalised log-semitone scale.

    Values below ``f0_min`` are treated as unvoiced and become 0. Voiced
    values become ``12 * log2(f0 / f0_min) + 1`` (so ``f0_min`` itself maps
    to 1), and the whole result is divided by ``scales * 12``, i.e. by
    ``scales`` octaves worth of semitones.

    Args:
        f0: torch tensor of F0 values in Hz. It is NOT modified.
        f0_min: lowest frequency treated as voiced; defaults to note C2.
        scales: number of octaves used for normalisation.

    Returns:
        A new tensor with the same shape, dtype and device as ``f0``.
    """
    # Use a mask instead of zeroing ``f0`` in place so the caller's tensor is
    # left untouched. ``>=`` also classifies NaN entries as unvoiced.
    voiced = f0 >= f0_min
    f0_log = torch.zeros_like(f0)
    # torch.log2 stays on the tensor's device; np.log2 would need a CPU copy.
    f0_log[voiced] = 12 * torch.log2(f0[voiced] / f0_min) + 1
    f0_log /= (scales * 12)
    return f0_log


class VCData(Dataset):
    """Voice-conversion dataset driven by a metadata CSV and a prompt CSV.

    Each item is ``(audio_clip, content_clip, f0_clip, speaker_clip, prompt)``:

    * ``audio_clip``: 1-D waveform crop of ``round(seg_length * sr)``
      samples, zero-padded.
    * ``content_clip``: ``(frames, C)`` content features aligned with the
      audio crop, padded by repeating the utterance's last frame.
    * ``f0_clip``: ``(frames, 1)`` normalised log-F0 (see ``log_f0``), or
      ``None`` when the row has no ``f0_path``. NOTE(review): a ``None``
      entry will not pass through torch's default collate function.
    * ``speaker_clip``: 1-D waveform crop of ``speaker_length`` seconds at
      ``speaker_sr`` from a random utterance of the reference speaker.
    * ``prompt``: a random element of that speaker's ``prompts`` list.

    In ``plugin_mode`` the first three entries are ``['']`` placeholders and
    no target audio/features are loaded.

    All file paths in the CSVs are joined to ``data_dir`` by plain string
    concatenation, so ``data_dir`` must end with a path separator.
    """

    def __init__(self, data_dir, meta_dir, subset, prompt_dir,
                 seg_length=1.92, speaker_length=4,
                 sr=24000, content_sr=50, speaker_sr=16000,
                 plugin_mode=False
                 ):
        self.datadir = data_dir
        meta = pd.read_csv(meta_dir)
        self.meta = meta[meta['subset'] == subset]
        self.subset = subset
        self.prompts = pd.read_csv(prompt_dir)

        self.seg_len = seg_length
        self.speaker_length = speaker_length
        self.sr = sr
        self.content_sr = content_sr
        self.speaker_sr = speaker_sr
        self.plugin_mode = plugin_mode

        # Lookup tables built once, instead of boolean-filtering the
        # DataFrames on every __getitem__ call (O(len(meta)) per item).
        self._speaker_paths = (
            self.meta.groupby('speaker')['speaker_path'].apply(list).to_dict()
        )
        # First row per ID wins, matching the previous ``.iloc[0]`` lookup.
        self._prompt_strings = (
            self.prompts.drop_duplicates('ID', keep='first')
            .set_index('ID')['prompts'].to_dict()
        )

    @staticmethod
    def _random_start(total, window):
        """Return a uniform start in ``[0, total - window]``, or 0 if the
        input is not longer than ``window``."""
        if total > window:
            return np.random.randint(0, total - window + 1)
        return 0

    def get_audio_content(self, audio_path, content_path, f0_path):
        """Load an aligned random crop of audio, content features and F0.

        A window of ``round(seg_len * content_sr)`` content frames is chosen
        uniformly at random (from frame 0 if the utterance is shorter), and
        the matching audio samples and F0 frames are cut with it.

        Args:
            audio_path: audio file path relative to ``data_dir``.
            content_path: ``torch.load``-able content tensor path relative to
                ``data_dir``; indexed as ``(1, frames, C)``.
            f0_path: F0 tensor path relative to ``data_dir``; an empty string
                or a missing CSV value (NaN) means "no F0".

        Returns:
            ``(audio_clip, content_clip, f0_clip)`` shaped ``(samples,)``,
            ``(frames, C)`` and ``(frames, 1)``; ``f0_clip`` is ``None`` when
            there is no F0 path.

        Raises:
            ValueError: if the audio file's sample rate is not ``self.sr``.
        """
        audio, sr = torchaudio.load(self.datadir + audio_path)
        if sr != self.sr:
            raise ValueError(
                f"{audio_path}: expected sample rate {self.sr}, got {sr}")

        # round() rather than int(): int() truncates float noise, e.g.
        # int(0.58 * 50) == 28, which would misalign frames and samples.
        n_frames = round(self.content_sr * self.seg_len)
        n_samples = round(self.seg_len * self.sr)

        # 1, T, C
        content = torch.load(self.datadir + content_path)
        total_length = content.shape[1]
        start = self._random_start(total_length, n_frames)
        end = min(start + n_frames, total_length)

        # Pad with the last frame. Padding only happens when the utterance is
        # shorter than the window, in which case start == 0 and end == T.
        # clone() turns the expanded view into writable memory.
        content_clip = repeat(content[:, -1, :], "b c -> b t c", t=n_frames).clone()
        content_clip[:, :end - start, :] = content[:, start:end, :]

        audio_start = round(start * self.sr / self.content_sr)
        audio_end = round(end * self.sr / self.content_sr)
        audio_clip = torch.zeros(n_samples)
        # The waveform can be slightly shorter than the frames imply, and
        # rounding can overshoot the buffer, so copy only what exists.
        audio_seg = audio[0, audio_start:audio_end][:n_samples]
        audio_clip[:audio_seg.shape[0]] = audio_seg

        # A missing CSV cell arrives as NaN, which is truthy, so check type.
        if isinstance(f0_path, str) and f0_path:
            f0 = torch.load(self.datadir + f0_path).float()
            f0_clip = torch.zeros(n_frames)
            # The F0 track may be shorter than the content features.
            f0_seg = f0[start:end]
            f0_clip[:f0_seg.shape[0]] = f0_seg
            f0_clip = log_f0(f0_clip).unsqueeze(-1)
        else:
            f0_clip = None

        return audio_clip, content_clip[0], f0_clip

    def get_speaker(self, speaker_path):
        """Load a random ``speaker_length``-second crop of a reference clip.

        Args:
            speaker_path: audio file path relative to ``data_dir``.

        Returns:
            1-D tensor of ``round(speaker_length * speaker_sr)`` samples
            (first channel only), zero-padded if the file is shorter.

        Raises:
            ValueError: if the file's sample rate is not ``self.speaker_sr``.
        """
        audio, sr = torchaudio.load(self.datadir + speaker_path)
        if sr != self.speaker_sr:
            raise ValueError(
                f"{speaker_path}: expected sample rate {self.speaker_sr}, got {sr}")

        # round() also allows a fractional speaker_length (e.g. 2.5 s).
        n_samples = round(self.speaker_length * self.speaker_sr)
        total_length = audio.shape[1]
        start = self._random_start(total_length, n_samples)
        end = min(start + n_samples, total_length)

        audio_clip = torch.zeros(n_samples)
        audio_clip[:end - start] = audio[0, start:end]
        return audio_clip

    def __getitem__(self, index):
        row = self.meta.iloc[index]

        if self.plugin_mode:
            # Target audio/features are not loaded in plugin mode.
            audio_clip, content_clip, f0_clip = [''], [''], ['']
        else:
            audio_clip, content_clip, f0_clip = self.get_audio_content(
                row['audio_path'], row['content_path'], row['f0_path'])

        # The train subset uses the row's own speaker; any other subset uses
        # the speaker named in its 'speaker_val' column.
        speaker = row['speaker'] if self.subset == 'train' else row['speaker_val']

        # Reference audio: a random utterance of that speaker in this subset.
        speaker_paths = self._speaker_paths.get(speaker)
        if not speaker_paths:
            raise KeyError(
                f"speaker {speaker!r} has no utterances in subset {self.subset!r}")
        speaker_path = speaker_paths[np.random.randint(len(speaker_paths))]
        speaker_clip = self.get_speaker(speaker_path)

        # The 'prompts' cell holds a Python-literal list of strings.
        prompt_literal = self._prompt_strings.get(speaker)
        if prompt_literal is None:
            raise KeyError(f"no prompts found for speaker {speaker!r}")
        prompt = random.choice(ast.literal_eval(prompt_literal))

        return audio_clip, content_clip, f0_clip, speaker_clip, prompt

    def __len__(self):
        return len(self.meta)


if __name__ == '__main__':
    from tqdm import tqdm

    data = VCData('../../features/', '../../data/meta_val.csv', 'val',
                  '../../data/speaker_gender.csv')
    # Smoke test: load every item once.
    for i in tqdm(range(len(data))):
        x = data[i]