import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import os
import random
from typing import Dict, List, Optional, Tuple
from .audio_processor import AudioProcessor
from ..configs.config import AudioConfig, Config

class SpeakerDataset(Dataset):
    """
    Speaker dataset: loads the audio data for a single speaker.
    """
    def __init__(
        self,
        audio_files: List[str],
        audio_processor: AudioProcessor,
        cache_size: int = 100  # simple in-memory cache
    ):
        self.audio_files = audio_files
        self.audio_processor = audio_processor
        self.cache = {}
        self.cache_size = cache_size

    def __len__(self) -> int:
        return len(self.audio_files)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        audio_path = self.audio_files[idx]
        # Serve from the cache when possible
        if audio_path in self.cache:
            return self.cache[audio_path]
        try:
            audio, mel_spec = self.audio_processor.preprocess_audio(audio_path)
            item = {
                'audio': torch.FloatTensor(audio),
                'mel_spec': torch.FloatTensor(mel_spec),
                'file_path': audio_path
            }
            # Cache the item until the cache is full (no eviction)
            if len(self.cache) < self.cache_size:
                self.cache[audio_path] = item
            return item
        except Exception as e:
            print(f"Error processing file {audio_path}: {str(e)}")
            # Fall back to the next sample in the dataset
            # (note: this loops indefinitely if no file is readable)
            return self.__getitem__((idx + 1) % len(self))
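
# Usage sketch (illustrative, not part of this module's API): SpeakerDataset is a
# plain torch Dataset, so it composes with DataLoader directly. The file paths and
# batch size below are hypothetical, and the printed shape assumes the
# AudioProcessor returns fixed-length mel spectrograms.
#
#   processor = AudioProcessor(config=Config().audio)
#   files = ["data/speaker_001/a.wav", "data/speaker_001/b.wav"]
#   loader = DataLoader(SpeakerDataset(files, processor), batch_size=2)
#   batch = next(iter(loader))
#   print(batch['mel_spec'].shape)  # [2, n_mels, time]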

class VoiceDatasetManager:
    """
    Dataset manager: organizes the dataset and samples tasks from it.
    """
    def __init__(
        self,
        root_dir: str,
        audio_processor: Optional[AudioProcessor] = None,
        config: Optional[Config] = None
    ):
        self.root_dir = root_dir
        self.config = config or Config()
        self.audio_processor = audio_processor or AudioProcessor(config=self.config.audio)
        self.speakers = self._scan_speakers()

    def _scan_speakers(self) -> Dict[str, List[str]]:
        speakers = {}
        for speaker_id in os.listdir(self.root_dir):
            speaker_dir = os.path.join(self.root_dir, speaker_id)
            if os.path.isdir(speaker_dir):
                audio_files = []
                # Recursively search all subdirectories
                for root, _, files in os.walk(speaker_dir):
                    for file in files:
                        # valid_audio_extensions is expected to be a tuple of suffixes
                        if file.endswith(self.config.data.valid_audio_extensions):
                            audio_path = os.path.join(root, file)
                            # Verify the file is accessible and non-empty
                            if os.path.exists(audio_path) and os.path.getsize(audio_path) > 0:
                                audio_files.append(audio_path)
                # Keep only speakers with enough samples
                if len(audio_files) >= self.config.data.min_samples_per_speaker:
                    speakers[speaker_id] = audio_files
                else:
                    print(f"Warning: Speaker {speaker_id} has insufficient samples")
        return speakers

    def get_speaker_dataset(self, speaker_id: str) -> SpeakerDataset:
        """Return the dataset for a specific speaker."""
        if speaker_id not in self.speakers:
            raise ValueError(f"Speaker {speaker_id} not found in dataset")
        return SpeakerDataset(
            self.speakers[speaker_id],
            self.audio_processor,
            cache_size=self.config.data.cache_size
        )
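
# Usage sketch (hedged; the directory layout is an assumption): the manager expects
# root_dir to hold one subdirectory per speaker, e.g. root_dir/speaker_001/*.wav,
# and drops any speaker with fewer than min_samples_per_speaker usable files.
#
#   manager = VoiceDatasetManager("data/voices", config=Config())
#   print(f"{len(manager.speakers)} usable speakers")
#   dataset = manager.get_speaker_dataset(next(iter(manager.speakers)))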

class MetaLearningDataset(Dataset):
    """
    Meta-learning dataset for few-shot voice cloning.
    Each item is one task, consisting of a support set and a query set.
    """
    def __init__(
        self,
        dataset_manager: VoiceDatasetManager,
        config: Config
    ):
        self.dataset_manager = dataset_manager
        self.config = config
        # Validate the dataset: a speaker is usable only if it has enough
        # files to fill both its support and its query set
        available_speakers = [
            speaker_id for speaker_id, files in dataset_manager.speakers.items()
            if len(files) >= (config.meta_learning.k_shot + config.meta_learning.k_query)
        ]
        if len(available_speakers) < config.meta_learning.n_way:
            raise ValueError(
                f"Not enough speakers with sufficient samples. "
                f"Need {config.meta_learning.n_way} speakers with "
                f"{config.meta_learning.k_shot + config.meta_learning.k_query} samples each, "
                f"but only found {len(available_speakers)}"
            )
        self.available_speakers = available_speakers

    def __len__(self) -> int:
        return self.config.meta_learning.n_tasks

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        Return the data for one task.
        Returns:
            support_data: dict with the support set
                - mel_spec: [n_way * k_shot, n_mels, time]
                - speaker_ids: [n_way * k_shot]
            query_data: dict with the query set
                - mel_spec: [n_way * k_query, n_mels, time]
                - speaker_ids: [n_way * k_query]
        """
        # Randomly select the speakers for this task
        selected_speakers = random.sample(self.available_speakers, self.config.meta_learning.n_way)
        support_data = {
            'mel_spec': [],
            'speaker_ids': []
        }
        query_data = {
            'mel_spec': [],
            'speaker_ids': []
        }
        for speaker_idx, speaker_id in enumerate(selected_speakers):
            speaker_files = self.dataset_manager.speakers[speaker_id]
            selected_files = random.sample(
                speaker_files,
                self.config.meta_learning.k_shot + self.config.meta_learning.k_query
            )
            for i, file_path in enumerate(selected_files):
                try:
                    _, mel_spec = self.dataset_manager.audio_processor.preprocess_audio(file_path)
                    mel_tensor = torch.FloatTensor(mel_spec)  # [n_mels, time]
                    # The first k_shot files form the support set, the rest the query set
                    target_dict = support_data if i < self.config.meta_learning.k_shot else query_data
                    target_dict['mel_spec'].append(mel_tensor)
                    target_dict['speaker_ids'].append(speaker_idx)
                except Exception as e:
                    print(f"Error processing {file_path}: {str(e)}")
                    continue
        # Convert lists to tensors
        # (torch.stack assumes preprocess_audio yields fixed-size mel spectrograms)
        for data_dict in [support_data, query_data]:
            if len(data_dict['mel_spec']) == 0:
                raise RuntimeError("No valid samples found for task")
            data_dict['mel_spec'] = torch.stack(data_dict['mel_spec'])
            data_dict['speaker_ids'] = torch.LongTensor(data_dict['speaker_ids'])
        return support_data, query_data
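
# Consumption sketch (illustrative; `manager` and `config` are the hypothetical
# objects from the sketch above): indexing the dataset yields one episode whose
# tensors feed straight into an episodic training loop.
#
#   task_dataset = MetaLearningDataset(manager, config)
#   support, query = task_dataset[0]
#   # support['mel_spec']:   [n_way * k_shot,  n_mels, time]
#   # query['speaker_ids']:  [n_way * k_query]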

def create_meta_learning_dataloader(
    root_dir: str,
    config: Optional[Config] = None,
    **kwargs
) -> DataLoader:
    """
    Create a DataLoader for meta-learning.
    Args:
        root_dir: root directory of the dataset
        config: configuration object
        **kwargs: overrides for individual meta-learning config fields
    Returns:
        DataLoader: the meta-learning data loader
    """
    config = config or Config()
    # Apply config overrides
    for key, value in kwargs.items():
        if hasattr(config.meta_learning, key):
            setattr(config.meta_learning, key, value)
    # Build the dataset manager and the dataset
    dataset_manager = VoiceDatasetManager(root_dir, config=config)
    dataset = MetaLearningDataset(dataset_manager, config)
    # Build the data loader
    return DataLoader(
        dataset,
        batch_size=1,  # fixed at 1: each sample is already a complete task
        shuffle=True,
        num_workers=0,  # avoid multiprocessing issues
        pin_memory=True,
        collate_fn=lambda x: x[0]  # strip the batch dimension
    )
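

if __name__ == "__main__":
    # Minimal smoke test, a sketch only: "data/voices" is a hypothetical dataset
    # root and the n_way/k_shot/k_query/n_tasks overrides are illustrative values
    # for the Config.meta_learning fields used throughout this module. Because of
    # the relative imports above, run this via `python -m <package>.<this_module>`
    # rather than executing the file directly.
    loader = create_meta_learning_dataloader(
        "data/voices", n_way=5, k_shot=1, k_query=2, n_tasks=10
    )
    support, query = next(iter(loader))
    print("support mel:", tuple(support['mel_spec'].shape))
    print("query ids:", query['speaker_ids'].tolist())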