Gregniuki committed on
Commit 5b1e4df · verified · 1 Parent(s): 86472e7

Delete model/dataset.py

Files changed (1)
model/dataset.py +0 -314
model/dataset.py DELETED
@@ -1,314 +0,0 @@
- import json
- import random
- from importlib.resources import files
-
- import torch
- import torch.nn.functional as F
- import torchaudio
- from datasets import Dataset as Dataset_
- from datasets import load_from_disk
- from torch import nn
- from torch.utils.data import Dataset, Sampler
- from tqdm import tqdm
-
- from f5_tts.model.modules import MelSpec
- from f5_tts.model.utils import default
-
-
- class HFDataset(Dataset):
-     def __init__(
-         self,
-         hf_dataset: Dataset,
-         target_sample_rate=24_000,
-         n_mel_channels=100,
-         hop_length=256,
-         n_fft=1024,
-         win_length=1024,
-         mel_spec_type="vocos",
-     ):
-         self.data = hf_dataset
-         self.target_sample_rate = target_sample_rate
-         self.hop_length = hop_length
-
-         self.mel_spectrogram = MelSpec(
-             n_fft=n_fft,
-             hop_length=hop_length,
-             win_length=win_length,
-             n_mel_channels=n_mel_channels,
-             target_sample_rate=target_sample_rate,
-             mel_spec_type=mel_spec_type,
-         )
-
-     def get_frame_len(self, index):
-         row = self.data[index]
-         audio = row["audio"]["array"]
-         sample_rate = row["audio"]["sampling_rate"]
-         return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length
-
-     def __len__(self):
-         return len(self.data)
-
-     def __getitem__(self, index):
-         row = self.data[index]
-         audio = row["audio"]["array"]
-
-         # logger.info(f"Audio shape: {audio.shape}")
-
-         sample_rate = row["audio"]["sampling_rate"]
-         duration = audio.shape[-1] / sample_rate
-
-         if duration > 30 or duration < 0.3:
-             return self.__getitem__((index + 1) % len(self.data))
-
-         audio_tensor = torch.from_numpy(audio).float()
-
-         if sample_rate != self.target_sample_rate:
-             resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
-             audio_tensor = resampler(audio_tensor)
-
-         audio_tensor = audio_tensor.unsqueeze(0)  # 't -> 1 t'
-
-         mel_spec = self.mel_spectrogram(audio_tensor)
-
-         mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'
-
-         text = row["text"]
-
-         return dict(
-             mel_spec=mel_spec,
-             text=text,
-         )
-
-
- class CustomDataset(Dataset):
-     def __init__(
-         self,
-         custom_dataset: Dataset,
-         durations=None,
-         target_sample_rate=24_000,
-         hop_length=256,
-         n_mel_channels=100,
-         n_fft=1024,
-         win_length=1024,
-         mel_spec_type="vocos",
-         preprocessed_mel=False,
-         mel_spec_module: nn.Module | None = None,
-     ):
-         self.data = custom_dataset
-         self.durations = durations
-         self.target_sample_rate = target_sample_rate
-         self.hop_length = hop_length
-         self.n_fft = n_fft
-         self.win_length = win_length
-         self.mel_spec_type = mel_spec_type
-         self.preprocessed_mel = preprocessed_mel
-
-         if not preprocessed_mel:
-             self.mel_spectrogram = default(
-                 mel_spec_module,
-                 MelSpec(
-                     n_fft=n_fft,
-                     hop_length=hop_length,
-                     win_length=win_length,
-                     n_mel_channels=n_mel_channels,
-                     target_sample_rate=target_sample_rate,
-                     mel_spec_type=mel_spec_type,
-                 ),
-             )
-
-     def get_frame_len(self, index):
-         if (
-             self.durations is not None
-         ):  # Please make sure the separately provided durations are correct, otherwise 99.99% OOM
-             return self.durations[index] * self.target_sample_rate / self.hop_length
-         return self.data[index]["duration"] * self.target_sample_rate / self.hop_length
-
-     def __len__(self):
-         return len(self.data)
-
-     def __getitem__(self, index):
-         row = self.data[index]
-         audio_path = row["audio_path"]
-         text = row["text"]
-         duration = row["duration"]
-
-         if self.preprocessed_mel:
-             mel_spec = torch.tensor(row["mel_spec"])
-
-         else:
-             audio, source_sample_rate = torchaudio.load(audio_path)
-             if audio.shape[0] > 1:
-                 audio = torch.mean(audio, dim=0, keepdim=True)
-
-             if duration > 30 or duration < 0.3:
-                 return self.__getitem__((index + 1) % len(self.data))
-
-             if source_sample_rate != self.target_sample_rate:
-                 resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
-                 audio = resampler(audio)
-
-             mel_spec = self.mel_spectrogram(audio)
-             mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'
-
-         return dict(
-             mel_spec=mel_spec,
-             text=text,
-         )
-
-
- # Dynamic Batch Sampler
-
-
- class DynamicBatchSampler(Sampler[list[int]]):
-     """Extension of Sampler that will do the following:
-     1.  Change the batch size (essentially number of sequences)
-         in a batch to ensure that the total number of frames are less
-         than a certain threshold.
-     2.  Make sure the padding efficiency in the batch is high.
-     """
-
-     def __init__(
-         self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
-     ):
-         self.sampler = sampler
-         self.frames_threshold = frames_threshold
-         self.max_samples = max_samples
-
-         indices, batches = [], []
-         data_source = self.sampler.data_source
-
-         for idx in tqdm(
-             self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
-         ):
-             indices.append((idx, data_source.get_frame_len(idx)))
-         indices.sort(key=lambda elem: elem[1])
-
-         batch = []
-         batch_frames = 0
-         for idx, frame_len in tqdm(
-             indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
-         ):
-             if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
-                 batch.append(idx)
-                 batch_frames += frame_len
-             else:
-                 if len(batch) > 0:
-                     batches.append(batch)
-                 if frame_len <= self.frames_threshold:
-                     batch = [idx]
-                     batch_frames = frame_len
-                 else:
-                     batch = []
-                     batch_frames = 0
-
-         if not drop_last and len(batch) > 0:
-             batches.append(batch)
-
-         del indices
-
-         # if want to have different batches between epochs, may just set a seed and log it in ckpt
-         # cuz during multi-gpu training, although the batch on per gpu not change between epochs, the formed general minibatch is different
-         # e.g. for epoch n, use (random_seed + n)
-         random.seed(random_seed)
-         random.shuffle(batches)
-
-         self.batches = batches
-
-     def __iter__(self):
-         return iter(self.batches)
-
-     def __len__(self):
-         return len(self.batches)
-
-
- # Load dataset
-
-
- def load_dataset(
-     dataset_name: str,
-     tokenizer: str = "pinyin",
-     dataset_type: str = "CustomDataset",
-     audio_type: str = "raw",
-     mel_spec_module: nn.Module | None = None,
-     mel_spec_kwargs: dict = dict(),
- ) -> CustomDataset | HFDataset:
-     """
-     dataset_type    - "CustomDataset" if you want to use tokenizer name and default data path to load for train_dataset
-                     - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on tokenizer
-     """
-
-     print("Loading dataset ...")
-
-     if dataset_type == "CustomDataset":
-         rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
-         if audio_type == "raw":
-             try:
-                 train_dataset = load_from_disk(f"{rel_data_path}/raw")
-             except:  # noqa: E722
-                 train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow")
-             preprocessed_mel = False
-         elif audio_type == "mel":
-             train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow")
-             preprocessed_mel = True
-         with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f:
-             data_dict = json.load(f)
-         durations = data_dict["duration"]
-         train_dataset = CustomDataset(
-             train_dataset,
-             durations=durations,
-             preprocessed_mel=preprocessed_mel,
-             mel_spec_module=mel_spec_module,
-             **mel_spec_kwargs,
-         )
-
-     elif dataset_type == "CustomDatasetPath":
-         try:
-             train_dataset = load_from_disk(f"{dataset_name}/raw")
-         except:  # noqa: E722
-             train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")
-
-         with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
-             data_dict = json.load(f)
-         durations = data_dict["duration"]
-         train_dataset = CustomDataset(
-             train_dataset, durations=durations, preprocessed_mel=preprocessed_mel, **mel_spec_kwargs
-         )
-
-     elif dataset_type == "HFDataset":
-         print(
-             "Should manually modify the path of huggingface dataset to your need.\n"
-             + "May also the corresponding script cuz different dataset may have different format."
-         )
-         pre, post = dataset_name.split("_")
-         train_dataset = HFDataset(
-             load_dataset(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))),
-         )
-
-     return train_dataset
-
-
- # collation
-
-
- def collate_fn(batch):
-     mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
-     mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
-     max_mel_length = mel_lengths.amax()
-
-     padded_mel_specs = []
-     for spec in mel_specs:  # TODO. maybe records mask for attention here
-         padding = (0, max_mel_length - spec.size(-1))
-         padded_spec = F.pad(spec, padding, value=0)
-         padded_mel_specs.append(padded_spec)
-
-     mel_specs = torch.stack(padded_mel_specs)
-
-     text = [item["text"] for item in batch]
-     text_lengths = torch.LongTensor([len(item) for item in text])
-
-     return dict(
-         mel=mel_specs,
-         mel_lengths=mel_lengths,
-         text=text,
-         text_lengths=text_lengths,
-     )