""" 
music_datasets.py
    Desc: Contains the code for the music datasets. 
"""

import torch
from torch.utils.data import Dataset
import torchaudio
import numpy as np
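

# A minimal sketch of the preprocessing step assumed by MusicMelDataset below:
# each .npy file holds a (frames, n_mels) mel-spectrogram computed at roughly
# 100 frames per second, i.e. hop_length = sample_rate / 100. The function
# name, paths, and parameters here are illustrative assumptions, not part of
# the original pipeline.
def precompute_mel(wav_path, out_path, sample_rate=16000, n_mels=80):
    waveform, sr = torchaudio.load(wav_path)
    if sr != sample_rate:
        waveform = torchaudio.functional.resample(waveform, sr, sample_rate)
    mel_transform = torchaudio.transforms.MelSpectrogram(
        sample_rate=sample_rate,
        n_fft=1024,
        hop_length=sample_rate // 100,  # ~100 mel frames per second
        n_mels=n_mels,
    )
    mel = mel_transform(waveform.mean(dim=0))  # mix to mono -> (n_mels, frames)
    np.save(out_path, mel.T.numpy())  # store as (frames, n_mels)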


""" 
MusicMelDataset:
    Given pre-processed mel-spectrograms, return a chunk of audio from the mel, with a masked version of a defined length 
    Args:
        audio_files: List of .npy files consisting of mel-specs
        audio_len: length in seconds (roughly) of audio to be return
        mask_ratio: Size of mask as a ration of audio_len
        mask_start: Where the mask starts for learning
            "midpoint": always mask out the second half of the mel-spec
        crop_start: Where the starting point for the sample of audio is taken
            "random": Random valid starting point from audio is taken

"""
class MusicMelDataset(Dataset):
    def __init__(self, audio_files, audio_len=6, mask_ratio=0.5, mask_start="midpoint", crop_start="random"):
        self.audio_files = audio_files

        # Convert length in seconds to a number of mel frames; 100 frames per
        # second is a heuristic conversion that depends on the hop length used
        # when the mels were computed
        self.audio_len = int(audio_len * 100)
        self.mask_ratio = mask_ratio
        self.mask_len = int(np.floor(self.audio_len * mask_ratio))
        self.mask_start = mask_start
        self.crop_start = crop_start

    def __len__(self):
        return len(self.audio_files)

    # Get a random crop of audio_len frames from the mel
    def get_random_crop(self, mel):
        # randint's upper bound is exclusive, so the last valid starting
        # frame, mel.shape[0] - self.audio_len, is included
        crop_start = torch.randint(0, mel.shape[0] - self.audio_len + 1, (1,)).item()
        return mel[crop_start:crop_start + self.audio_len, :]

    def __getitem__(self, idx):
        mel = torch.Tensor(np.load(self.audio_files[idx]))

        if self.crop_start == "random":
            mel = self.get_random_crop(mel)
        else:
            raise NotImplementedError(f"{self.crop_start} is not an implemented parameter for crop_start")

        mask = torch.ones_like(mel)
        if self.mask_start == "midpoint":
            # Zero out mask_len frames starting at the midpoint of the chunk;
            # when mask_ratio == 0.5 this masks exactly the second half
            midpoint = self.audio_len // 2
            mask[midpoint:midpoint + self.mask_len, :] = 0
        else:
            raise NotImplementedError(f"{self.mask_start} is not an implemented parameter for mask_start")

        mel_mask = mel * mask

        return mel, mel_mask
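

# A quick usage sketch: the dataset feeds a standard DataLoader, and each
# batch is a (mel, masked mel) pair of shape (batch, frames, n_mels). The
# file names below are hypothetical placeholders.
if __name__ == "__main__":
    from torch.utils.data import DataLoader

    files = ["song_0.npy", "song_1.npy"]  # hypothetical pre-processed mels
    dataset = MusicMelDataset(files, audio_len=6, mask_ratio=0.5)
    loader = DataLoader(dataset, batch_size=2, shuffle=True)
    mel, mel_mask = next(iter(loader))
    print(mel.shape, mel_mask.shape)  # e.g. torch.Size([2, 600, 80]) twice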