File size: 1,826 Bytes
b3fb4dd
 
 
49ebc1f
b3fb4dd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
49ebc1f
b3fb4dd
 
49ebc1f
b3fb4dd
 
 
 
 
 
49ebc1f
 
 
 
 
 
 
b3fb4dd
 
 
49ebc1f
b3fb4dd
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
import torch
from torch.utils.data import IterableDataset
from torch.fft import fft
import torch.nn.functional as F
from itertools import tee
import random
import torchaudio.transforms as T


class SplitDataset(IterableDataset):
    """Deterministic train/validation view over an iterable dataset.

    The wrapped dataset is consumed in order and treated as consecutive
    cycles of 100 items. Within every cycle the first
    ``round(train_ratio * 100)`` items belong to the training split and the
    remaining items to the validation split. Two instances wrapping the same
    ordered source, one with ``is_train=True`` and one with
    ``is_train=False``, therefore yield disjoint, complementary subsets.

    Args:
        dataset: Any iterable of items; it is re-iterated on each
            ``__iter__`` call.
        is_train: Yield the training portion when true, otherwise the
            validation portion.
        train_ratio: Fraction of each 100-item cycle assigned to training.
    """

    # Number of consecutive items that form one split cycle.
    _CYCLE = 100

    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        super().__init__()
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio

    def __iter__(self):
        """Yield the items of the selected split in source order."""
        # round() rather than int(): products such as 0.29 * 100 evaluate to
        # 28.999999999999996, which int() would truncate to 28 and silently
        # shift one item per cycle into the wrong split.
        train_slots = round(self.train_ratio * self._CYCLE)
        for position, item in enumerate(self.dataset):
            in_train = position % self._CYCLE < train_slots
            if in_train == self.is_train:
                yield item


class FFTDataset(IterableDataset):
    """Wrap a dataset of audio items and attach a resampled waveform + FFT.

    Each source item is expected to expose ``item['audio']['array']`` holding
    the raw samples (assumed to be a 1-D sequence; confirm against the
    upstream dataset). The samples are cut or zero-padded to ``max_len``,
    resampled, and Fourier-transformed.

    Args:
        original_dataset: Iterable of mapping-like items with an ``'audio'``
            entry that itself contains an ``'array'`` entry.
        max_len: Number of samples each clip is truncated/padded to *before*
            resampling.
        orig_sample_rate: Sample rate passed to the resampler as ``orig_freq``.
        target_sample_rate: Sample rate passed to the resampler as ``new_freq``.
    """

    def __init__(self, original_dataset, max_len=72000, orig_sample_rate=12000, target_sample_rate=3000):
        super().__init__()
        self.dataset = original_dataset
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        self.max_len = max_len

    def __iter__(self):
        """Yield a processed copy of every source item.

        The yielded item has ``item['audio']['array']`` replaced by the
        padded and resampled float32 waveform, and ``item['audio']['fft']``
        set to the FFT of that waveform. Source items are left untouched so
        that a re-iterable source can be traversed again without processing
        already-processed data.
        """
        for item in self.dataset:
            # as_tensor avoids the extra copy (and the copy-construct warning
            # for tensor inputs) that torch.tensor(...).float() incurs.
            waveform = torch.as_tensor(item['audio']['array'], dtype=torch.float32)

            # Force a fixed length: drop samples beyond max_len, then pad the
            # tail with zeros. An empty clip becomes max_len zeros.
            waveform = waveform[: self.max_len]
            missing = self.max_len - len(waveform)
            waveform = F.pad(waveform, (0, missing), mode='constant')

            waveform = self.resampler(waveform)

            # Shallow-copy the containers instead of mutating the source item.
            audio = dict(item['audio'])
            audio['array'] = waveform
            audio['fft'] = fft(waveform)
            processed = dict(item)
            processed['audio'] = audio
            yield processed