import torch
from torch.utils.data import IterableDataset
from torch.fft import fft
import torch.nn.functional as F
from itertools import tee
import random
import torchaudio.transforms as T


class SplitDataset(IterableDataset):
    """Deterministically split an iterable dataset into train/validation streams.

    Items are assigned in repeating cycles of 100: the first
    ``round(train_ratio * 100)`` items of every cycle go to the training
    stream, the rest to validation. Two instances wrapping the same
    underlying iterable (one with ``is_train=True``, one with ``False``)
    therefore yield disjoint, complementary subsets.

    Args:
        dataset: any iterable of items.
        is_train: if True, yield the training portion; otherwise validation.
        train_ratio: fraction (0..1) of each 100-item cycle sent to training.
    """

    def __init__(self, dataset, is_train=True, train_ratio=0.8):
        self.dataset = dataset
        self.is_train = is_train
        self.train_ratio = train_ratio

    def __iter__(self):
        # Hoisted out of the loop; round() avoids float-truncation surprises
        # (e.g. int(0.29 * 100) == 28 because 0.29 * 100 == 28.999...).
        threshold = round(self.train_ratio * 100)
        count = 0
        for item in self.dataset:
            # Positions 0..threshold-1 of each 100-item cycle are "train".
            is_train_item = count < threshold
            if is_train_item == self.is_train:
                yield item
            count = (count + 1) % 100


class FFTDataset(IterableDataset):
    """Wrap an audio dataset: fix length, resample, and attach an FFT.

    Each yielded item mirrors the source item, with
    ``item['audio']['array']`` replaced by the truncated/zero-padded and
    resampled waveform, and ``item['audio']['fft']`` holding its complex FFT.

    Args:
        original_dataset: iterable of items shaped like
            ``{'audio': {'array': <1-D numeric sequence>, ...}, ...}``
            (assumed — confirm against the producing dataset).
        max_len: fixed waveform length (in samples at ``orig_sample_rate``)
            before resampling; shorter clips are right-padded with zeros,
            longer clips are truncated.
        orig_sample_rate: sample rate of the incoming audio.
        target_sample_rate: sample rate after resampling.
    """

    def __init__(self, original_dataset, max_len=72000, orig_sample_rate=12000, target_sample_rate=3000):
        self.dataset = original_dataset
        self.resampler = T.Resample(orig_freq=orig_sample_rate, new_freq=target_sample_rate)
        self.max_len = max_len

    def __iter__(self):
        for item in self.dataset:
            audio_data = torch.tensor(item['audio']['array']).float()
            # Explicitly truncate clips longer than max_len. The original
            # code produced a negative pad_len here, relying on F.pad's
            # negative-padding trim behavior; this makes the intent clear
            # and yields the same tensor.
            if len(audio_data) > self.max_len:
                audio_data = audio_data[:self.max_len]
            # Right-pad with zeros to the fixed length before resampling.
            pad_len = self.max_len - len(audio_data)
            audio_data = F.pad(audio_data, (0, pad_len), mode='constant')
            audio_data = self.resampler(audio_data)
            fft_data = fft(audio_data)
            # Attach the FFT and the processed waveform back onto the item
            # (mutates the incoming dict in place).
            item['audio']['fft'] = fft_data
            item['audio']['array'] = audio_data
            yield item