import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import random
from typing import Dict, List, Optional, Tuple

from .audio_processor import AudioProcessor
from ..configs.config import AudioConfig, Config


class SpeakerDataset(Dataset):
    """
    Speaker dataset: loads the audio data of a single speaker.
    """
    def __init__(
        self,
        audio_files: List[str],
        audio_processor: AudioProcessor,
        cache_size: int = 100  # simple in-memory cache of preprocessed items
    ):
        self.audio_files = audio_files
        self.audio_processor = audio_processor
        self.cache = {}
        self.cache_size = cache_size

    def __len__(self) -> int:
        return len(self.audio_files)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        audio_path = self.audio_files[idx]

        # Serve from the cache when possible
        if audio_path in self.cache:
            return self.cache[audio_path]

        try:
            audio, mel_spec = self.audio_processor.preprocess_audio(audio_path)
            item = {
                'audio': torch.FloatTensor(audio),
                'mel_spec': torch.FloatTensor(mel_spec),
                'file_path': audio_path
            }

            # Update the cache until it reaches cache_size
            if len(self.cache) < self.cache_size:
                self.cache[audio_path] = item

            return item

        except Exception as e:
            print(f"Error processing file {audio_path}: {str(e)}")
            # Fall back to the next valid sample in the dataset
            return self.__getitem__((idx + 1) % len(self))


class VoiceDatasetManager:
    """
    Dataset manager: organizes the dataset and supports task sampling.
    """
    def __init__(
        self,
        root_dir: str,
        audio_processor: Optional[AudioProcessor] = None,
        config: Optional[Config] = None
    ):
        self.root_dir = root_dir
        self.config = config or Config()
        self.audio_processor = audio_processor or AudioProcessor(config=self.config.audio)
        self.speakers = self._scan_speakers()

    def _scan_speakers(self) -> Dict[str, List[str]]:
        speakers = {}
        for speaker_id in os.listdir(self.root_dir):
            speaker_dir = os.path.join(self.root_dir, speaker_id)
            if os.path.isdir(speaker_dir):
                audio_files = []
                # Recursively search all subdirectories
                for root, _, files in os.walk(speaker_dir):
                    for file in files:
                        if file.endswith(self.config.data.valid_audio_extensions):
                            audio_path = os.path.join(root, file)
                            # Verify that the file is accessible and non-empty
                            if os.path.exists(audio_path) and os.path.getsize(audio_path) > 0:
                                audio_files.append(audio_path)

                # Keep only speakers with enough samples
                if len(audio_files) >= self.config.data.min_samples_per_speaker:
                    speakers[speaker_id] = audio_files
                else:
                    print(f"Warning: Speaker {speaker_id} has insufficient samples")

        return speakers

    def get_speaker_dataset(self, speaker_id: str) -> SpeakerDataset:
        """Return the dataset of a specific speaker."""
        if speaker_id not in self.speakers:
            raise ValueError(f"Speaker {speaker_id} not found in dataset")

        return SpeakerDataset(
            self.speakers[speaker_id],
            self.audio_processor,
            cache_size=self.config.data.cache_size
        )


class MetaLearningDataset(Dataset):
    """
    Meta-learning dataset for few-shot voice cloning.
    Each item is one task, consisting of a support set and a query set.
    """
    def __init__(
        self,
        dataset_manager: VoiceDatasetManager,
        config: Config
    ):
        self.dataset_manager = dataset_manager
        self.config = config

        # Validate the dataset: a speaker is usable only if it has enough
        # samples for both the support set and the query set
        available_speakers = [
            speaker_id for speaker_id, files in dataset_manager.speakers.items()
            if len(files) >= (config.meta_learning.k_shot + config.meta_learning.k_query)
        ]

        if len(available_speakers) < config.meta_learning.n_way:
            raise ValueError(
                f"Not enough speakers with sufficient samples. "
                f"Need {config.meta_learning.n_way} speakers with "
                f"{config.meta_learning.k_shot + config.meta_learning.k_query} samples each, "
                f"but only found {len(available_speakers)}"
            )

        self.available_speakers = available_speakers

    def __len__(self) -> int:
        return self.config.meta_learning.n_tasks

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Return the data of one task.

        Returns:
            support_data: dict with the support set
                - mel_spec: [n_way*k_shot, n_mels, time]
                - speaker_ids: [n_way*k_shot]
            query_data: dict with the query set
                - mel_spec: [n_way*k_query, n_mels, time]
                - speaker_ids: [n_way*k_query]
        """
        # Randomly select speakers for this task
        selected_speakers = random.sample(self.available_speakers, self.config.meta_learning.n_way)

        support_data = {
            'mel_spec': [],
            'speaker_ids': []
        }
        query_data = {
            'mel_spec': [],
            'speaker_ids': []
        }

        for speaker_idx, speaker_id in enumerate(selected_speakers):
            speaker_files = self.dataset_manager.speakers[speaker_id]
            selected_files = random.sample(
                speaker_files,
                self.config.meta_learning.k_shot + self.config.meta_learning.k_query
            )

            for i, file_path in enumerate(selected_files):
                try:
                    _, mel_spec = self.dataset_manager.audio_processor.preprocess_audio(file_path)
                    mel_tensor = torch.FloatTensor(mel_spec)  # [n_mels, time]

                    # First k_shot files go to the support set, the rest to the query set
                    target_dict = support_data if i < self.config.meta_learning.k_shot else query_data
                    target_dict['mel_spec'].append(mel_tensor)
                    target_dict['speaker_ids'].append(speaker_idx)

                except Exception as e:
                    print(f"Error processing {file_path}: {str(e)}")
                    continue

        # Convert lists to tensors; stacking assumes preprocess_audio yields
        # mel spectrograms of a fixed length
        for data_dict in [support_data, query_data]:
            if len(data_dict['mel_spec']) == 0:
                raise RuntimeError("No valid samples found for task")
            data_dict['mel_spec'] = torch.stack(data_dict['mel_spec'])
            data_dict['speaker_ids'] = torch.LongTensor(data_dict['speaker_ids'])

        return support_data, query_data


def create_meta_learning_dataloader(
    root_dir: str,
    config: Optional[Config] = None,
    **kwargs
) -> DataLoader:
    """
    Create a data loader for meta-learning.

    Args:
        root_dir: dataset root directory
        config: configuration object
        **kwargs: overrides for meta-learning config fields

    Returns:
        DataLoader: meta-learning data loader
    """
    config = config or Config()

    # Apply config overrides
    for key, value in kwargs.items():
        if hasattr(config.meta_learning, key):
            setattr(config.meta_learning, key, value)

    # Create the dataset manager
    dataset_manager = VoiceDatasetManager(root_dir, config=config)

    # Create the dataset
    dataset = MetaLearningDataset(dataset_manager, config)

    # Create the data loader
    return DataLoader(
        dataset,
        batch_size=1,  # fixed at 1: each item already contains a full task
        shuffle=True,
        num_workers=0,  # avoid multiprocessing issues
        pin_memory=True,
        collate_fn=lambda x: x[0]  # strip the batch dimension
    )