import ast
import random

import numpy as np
import pandas as pd
import torch
import torchaudio
import librosa
from einops import repeat
from torch.utils.data import Dataset


def log_f0(f0, f0_min=librosa.note_to_hz('C2'), scales=4):
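    """Map F0 in Hz to a semitone scale relative to f0_min, normalised by
    `scales` octaves; frames below f0_min are treated as unvoiced (0)."""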
    # treat everything below the minimum pitch as unvoiced
    f0[f0 < f0_min] = 0.0
    f0_log = torch.zeros_like(f0)
    # semitone distance from f0_min, offset by 1 so voiced frames stay non-zero
    f0_log[f0 != 0] = 12 * np.log2(f0[f0 != 0] / f0_min) + 1
    # normalise by `scales` octaves (12 semitones each)
    f0_log /= (scales * 12)
    return f0_log


class VCData(Dataset):
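    """Voice-conversion dataset: each item pairs a fixed-length audio segment,
    its content features and F0, with a speaker-reference clip and a text
    prompt for the same speaker."""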
    def __init__(self,
                 data_dir, meta_dir, subset, prompt_dir,
                 seg_length=1.92, speaker_length=4,
                 sr=24000, content_sr=50, speaker_sr=16000,
                 plugin_mode=False):
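        """
        Args:
            data_dir: root directory that the relative paths in the meta csv
                are resolved against.
            meta_dir: metadata csv with columns such as 'subset', 'audio_path',
                'content_path', 'f0_path', 'speaker', 'speaker_val',
                'speaker_path'.
            subset: split to load, e.g. 'train' or 'val'.
            prompt_dir: csv of per-speaker text prompts (columns 'ID', 'prompts').
            seg_length: audio/content segment length in seconds.
            speaker_length: speaker-reference clip length in seconds.
            sr: sample rate of the training audio.
            content_sr: frame rate of the content features (frames per second).
            speaker_sr: sample rate of the speaker-reference audio.
            plugin_mode: if True, skip loading audio/content/F0 and return
                placeholders for them.
        """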
        self.datadir = data_dir
        meta = pd.read_csv(meta_dir)
        self.meta = meta[meta['subset'] == subset]
        self.subset = subset
        self.prompts = pd.read_csv(prompt_dir)
        self.seg_len = seg_length
        self.speaker_length = speaker_length
        self.sr = sr
        self.content_sr = content_sr
        self.speaker_sr = speaker_sr
        self.plugin_mode = plugin_mode

    def get_audio_content(self, audio_path, content_path, f0_path):
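        """Load a seg_len-second training segment.

        Returns:
            audio_clip: (seg_len * sr,) waveform, zero-padded at the end.
            content_clip: (seg_len * content_sr, C) content features, padded
                with the last frame.
            f0_clip: (seg_len * content_sr, 1) log-normalised F0, or None if
                f0_path is empty.
        """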
        audio_path = self.datadir + audio_path
        audio, sr = torchaudio.load(audio_path)
        assert sr == self.sr

        # content features have shape (1, T, C)
        content = torch.load(self.datadir + content_path)

        # pick a random start frame so a seg_len-second window fits in the utterance
        total_length = content.shape[1]
        seg_frames = int(self.content_sr * self.seg_len)
        if total_length - seg_frames > 0:
            start = np.random.randint(0, total_length - seg_frames + 1)
        else:
            start = 0
        end = min(start + seg_frames, total_length)

        # pad short clips by repeating the last content frame
        content_clip = repeat(content[:, -1, :], "b c -> b t c", t=seg_frames).clone()
        content_clip[:, :end - start, :] = content[:, start:end, :]

        # cut the matching waveform span and zero-pad to the fixed clip length
        audio_clip = torch.zeros(int(self.seg_len * self.sr))
        audio_start = round(start * self.sr / self.content_sr)
        audio_end = min(round(end * self.sr / self.content_sr), audio.shape[1])
        audio_clip[:audio_end - audio_start] = audio[0, audio_start:audio_end].clone()

        if f0_path:
            f0 = torch.load(self.datadir + f0_path).float()
            f0_clip = torch.zeros(int(self.content_sr * self.seg_len))
            f0_clip[:end-start] = f0[start:end]
            f0_clip = log_f0(f0_clip)
            f0_clip = f0_clip.unsqueeze(-1)
        else:
            f0_clip = None

        return audio_clip, content_clip[0], f0_clip

    def get_speaker(self, speaker_path):
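        """Load a random speaker_length-second reference clip at speaker_sr,
        zero-padded if the source file is shorter than the window."""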
        audio_path = self.datadir + speaker_path
        audio, sr = torchaudio.load(audio_path)
        assert sr == self.speaker_sr

        audio_clip = torch.zeros(self.speaker_length * self.speaker_sr)

        # pick a random window; zero-pad if the file is shorter than the window
        total_length = audio.shape[1]
        ref_samples = self.speaker_sr * self.speaker_length
        if total_length - ref_samples > 0:
            start = np.random.randint(0, total_length - ref_samples + 1)
        else:
            start = 0
        end = min(start + ref_samples, total_length)

        audio_clip[:end - start] = audio[0, start:end]

        return audio_clip

    def __getitem__(self, index):
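        """Return (audio_clip, content_clip, f0_clip, speaker_clip, prompt) for
        one metadata row; the speaker clip comes from a random utterance of the
        target speaker."""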
        row = self.meta.iloc[index]

        if self.plugin_mode:
            # plugin mode: skip audio/content/F0 and return placeholders
            audio_clip, content_clip, f0_clip = [''], [''], ['']
        else:
            # load current audio
            audio_path = row['audio_path']
            content_path = row['content_path']
            f0_path = row['f0_path']
            audio_clip, content_clip, f0_clip = self.get_audio_content(audio_path, content_path, f0_path)

        # training uses the utterance's own speaker; validation uses a designated target speaker
        if self.subset == 'train':
            speaker = row['speaker']
        else:
            speaker = row['speaker_val']

        # sample a random reference utterance from that speaker
        speaker_row = self.meta[self.meta['speaker'] == speaker].sample(1)
        speaker_path = speaker_row.iloc[0]['speaker_path']
        speaker_clip = self.get_speaker(speaker_path)

        # each speaker has a list of text prompts stored as a Python literal string
        prompts = self.prompts[self.prompts['ID'] == speaker]['prompts'].iloc[0]
        prompts = ast.literal_eval(prompts)
        prompt = random.choice(prompts)

        return audio_clip, content_clip, f0_clip, speaker_clip, prompt

    def __len__(self):
        return len(self.meta)


if __name__ == '__main__':
    # quick sanity check: iterate once over the validation split
    from tqdm import tqdm
    data = VCData('../../features/', '../../data/meta_val.csv', 'val', '../../data/speaker_gender.csv')
    for i in tqdm(range(len(data))):
        x = data[i]