import torch
from torch.utils.data import Dataset, DataLoader
import os
import random
from typing import Dict, List, Optional, Tuple
from .audio_processor import AudioProcessor
from ..configs.config import Config

class SpeakerDataset(Dataset):
    """
    说话人数据集:用于加载单个说话人的音频数据
    """
    def __init__(
        self,
        audio_files: List[str],
        audio_processor: AudioProcessor,
        cache_size: int = 100  # size of the in-memory preprocessing cache
    ):
        self.audio_files = audio_files
        self.audio_processor = audio_processor
        self.cache = {}
        self.cache_size = cache_size

    def __len__(self) -> int:
        return len(self.audio_files)

    def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
        audio_path = self.audio_files[idx]
        
        # serve preprocessed items from the cache when available
        if audio_path in self.cache:
            return self.cache[audio_path]
            
        try:
            audio, mel_spec = self.audio_processor.preprocess_audio(audio_path)
            item = {
                'audio': torch.FloatTensor(audio),
                'mel_spec': torch.FloatTensor(mel_spec),
                'file_path': audio_path
            }
            
            # populate the cache until it reaches capacity (no eviction)
            if len(self.cache) < self.cache_size:
                self.cache[audio_path] = item
                
            return item
        except Exception as e:
            print(f"Error processing file {audio_path}: {str(e)}")
            # fall back to the next sample; note this recurses indefinitely if every file fails
            return self.__getitem__((idx + 1) % len(self))
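
# SpeakerDataset items carry variable-length 'audio' and 'mel_spec' tensors, so
# PyTorch's default collate would fail when batching them. The helper below is a
# minimal sketch (an addition, not part of the original API) of a collate_fn that
# zero-pads each field along its last dimension to the longest item in the batch.
def pad_speaker_collate(batch: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
    """Zero-pad variable-length fields to the batch maximum and stack them."""
    max_audio = max(item['audio'].shape[-1] for item in batch)
    max_time = max(item['mel_spec'].shape[-1] for item in batch)
    return {
        'audio': torch.stack([
            torch.nn.functional.pad(item['audio'], (0, max_audio - item['audio'].shape[-1]))
            for item in batch
        ]),
        'mel_spec': torch.stack([
            torch.nn.functional.pad(item['mel_spec'], (0, max_time - item['mel_spec'].shape[-1]))
            for item in batch
        ]),
        'file_paths': [item['file_path'] for item in batch],
    }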

class VoiceDatasetManager:
    """
    数据集管理器:负责数据集的组织和任务采样
    """
    def __init__(
        self,
        root_dir: str,
        audio_processor: Optional[AudioProcessor] = None,
        config: Optional[Config] = None
    ):
        self.root_dir = root_dir
        self.config = config or Config()
        self.audio_processor = audio_processor or AudioProcessor(config=self.config.audio)
        self.speakers = self._scan_speakers()
        
    def _scan_speakers(self) -> Dict[str, List[str]]:
        speakers = {}
        for speaker_id in os.listdir(self.root_dir):
            speaker_dir = os.path.join(self.root_dir, speaker_id)
            if os.path.isdir(speaker_dir):
                audio_files = []
                # recursively search all subdirectories
                for root, _, files in os.walk(speaker_dir):
                    for file in files:
                        # note: valid_audio_extensions must be a str or tuple of str for endswith
                        if file.endswith(self.config.data.valid_audio_extensions):
                            audio_path = os.path.join(root, file)
                            # verify the file is accessible and non-empty
                            if os.path.exists(audio_path) and os.path.getsize(audio_path) > 0:
                                audio_files.append(audio_path)
                                
                # keep only speakers with enough valid samples
                if len(audio_files) >= self.config.data.min_samples_per_speaker:
                    speakers[speaker_id] = audio_files
                else:
                    print(f"Warning: Speaker {speaker_id} has insufficient samples "
                          f"({len(audio_files)} < {self.config.data.min_samples_per_speaker}), skipping")
        return speakers

    def get_speaker_dataset(self, speaker_id: str) -> SpeakerDataset:
        """获取特定说话人的数据集"""
        if speaker_id not in self.speakers:
            raise ValueError(f"Speaker {speaker_id} not found in dataset")
        return SpeakerDataset(
            self.speakers[speaker_id],
            self.audio_processor,
            cache_size=self.config.data.cache_size
        )
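
# Example (hypothetical root path), combining the manager with pad_speaker_collate:
#
#     manager = VoiceDatasetManager("data/speakers")
#     speaker_id = next(iter(manager.speakers))
#     loader = DataLoader(manager.get_speaker_dataset(speaker_id),
#                         batch_size=4, collate_fn=pad_speaker_collate)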

class MetaLearningDataset(Dataset):
    """
    元学习数据集:用于少样本语音克隆的训练
    每次返回一个任务的数据,包含支持集和查询集
    """
    def __init__(
        self,
        dataset_manager: VoiceDatasetManager,
        config: Config
    ):
        self.dataset_manager = dataset_manager
        self.config = config
        
        # keep only speakers with enough samples for a full task (k_shot + k_query)
        available_speakers = [
            speaker_id for speaker_id, files in dataset_manager.speakers.items()
            if len(files) >= (config.meta_learning.k_shot + config.meta_learning.k_query)
        ]
        
        if len(available_speakers) < config.meta_learning.n_way:
            raise ValueError(
                f"Not enough speakers with sufficient samples. "
                f"Need {config.meta_learning.n_way} speakers with "
                f"{config.meta_learning.k_shot + config.meta_learning.k_query} samples each, "
                f"but only found {len(available_speakers)}"
            )
            
        self.available_speakers = available_speakers

    def __len__(self) -> int:
        return self.config.meta_learning.n_tasks

    def __getitem__(self, idx: int) -> Tuple[Dict[str, torch.Tensor], Dict[str, torch.Tensor]]:
        """
        返回一个任务的数据
        Returns:
            support_data: 包含支持集数据的字典
                - mel_spec: [n_way*k_shot, n_mels, time]
                - speaker_ids: [n_way*k_shot]
            query_data: 包含查询集数据的字典
                - mel_spec: [n_way*k_query, n_mels, time]
                - speaker_ids: [n_way*k_query]
        """
        # randomly select n_way speakers for this task
        selected_speakers = random.sample(self.available_speakers, self.config.meta_learning.n_way)
        
        support_data = {
            'mel_spec': [],
            'speaker_ids': []
        }
        
        query_data = {
            'mel_spec': [],
            'speaker_ids': []
        }
        
        for speaker_idx, speaker_id in enumerate(selected_speakers):
            speaker_files = self.dataset_manager.speakers[speaker_id]
            selected_files = random.sample(
                speaker_files,
                self.config.meta_learning.k_shot + self.config.meta_learning.k_query
            )
            
            for i, file_path in enumerate(selected_files):
                try:
                    _, mel_spec = self.dataset_manager.audio_processor.preprocess_audio(file_path)
                    mel_tensor = torch.FloatTensor(mel_spec)  # [n_mels, time]
                    
                    # the first k_shot files go to the support set, the remainder to the query set
                    target_dict = support_data if i < self.config.meta_learning.k_shot else query_data
                    target_dict['mel_spec'].append(mel_tensor)
                    target_dict['speaker_ids'].append(speaker_idx)
                    
                except Exception as e:
                    print(f"Error processing {file_path}: {str(e)}")
                    continue
        
        # stack into tensors; this assumes preprocess_audio returns fixed-length
        # mel spectrograms, otherwise torch.stack will fail on mismatched shapes
        for data_dict in [support_data, query_data]:
            if len(data_dict['mel_spec']) == 0:
                raise RuntimeError("No valid samples found for task")
                
            data_dict['mel_spec'] = torch.stack(data_dict['mel_spec'])
            data_dict['speaker_ids'] = torch.LongTensor(data_dict['speaker_ids'])
            
        return support_data, query_data

def create_meta_learning_dataloader(
    root_dir: str,
    config: Optional[Config] = None,
    **kwargs
) -> DataLoader:
    """
    创建用于元学习的数据加载器
    
    Args:
        root_dir: 数据集根目录
        config: 配置对象
        **kwargs: 其他参数
        
    Returns:
        DataLoader: 元学习数据加载器
    """
    config = config or Config()
    
    # apply keyword overrides to the meta-learning config
    for key, value in kwargs.items():
        if hasattr(config.meta_learning, key):
            setattr(config.meta_learning, key, value)
    
    # build the dataset manager
    dataset_manager = VoiceDatasetManager(root_dir, config=config)
    
    # build the meta-learning dataset
    dataset = MetaLearningDataset(dataset_manager, config)
    
    # build the data loader
    return DataLoader(
        dataset,
        batch_size=1,  # fixed at 1: each item is already a complete task
        shuffle=True,
        num_workers=0,  # single-process loading; the lambda collate_fn is not picklable under spawn
        pin_memory=True,
        collate_fn=lambda x: x[0]  # unwrap the single-task batch
    )
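
# Minimal usage sketch. The directory layout (<root>/<speaker_id>/**/*.wav) and the
# meta-learning field names (n_way, k_shot, k_query) are assumptions inferred from
# the config attributes referenced above; the path itself is hypothetical.
if __name__ == "__main__":
    loader = create_meta_learning_dataloader(
        "data/speakers",
        n_way=5,
        k_shot=1,
        k_query=5,
    )
    support, query = next(iter(loader))
    print("support mel_spec:", tuple(support['mel_spec'].shape))  # [n_way*k_shot, n_mels, time]
    print("query mel_spec:", tuple(query['mel_spec'].shape))      # [n_way*k_query, n_mels, time]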