Delete model/dataset.py

model/dataset.py  +0 -314  DELETED

@@ -1,314 +0,0 @@
import json
import random
from importlib.resources import files

import torch
import torch.nn.functional as F
import torchaudio
from datasets import Dataset as Dataset_
from datasets import load_dataset as load_dataset_  # aliased: the local load_dataset() below would otherwise shadow it
from datasets import load_from_disk
from torch import nn
from torch.utils.data import Dataset, Sampler
from tqdm import tqdm

from f5_tts.model.modules import MelSpec
from f5_tts.model.utils import default


class HFDataset(Dataset):
    def __init__(
        self,
        hf_dataset: Dataset,
        target_sample_rate=24_000,
        n_mel_channels=100,
        hop_length=256,
        n_fft=1024,
        win_length=1024,
        mel_spec_type="vocos",
    ):
        self.data = hf_dataset
        self.target_sample_rate = target_sample_rate
        self.hop_length = hop_length

        self.mel_spectrogram = MelSpec(
            n_fft=n_fft,
            hop_length=hop_length,
            win_length=win_length,
            n_mel_channels=n_mel_channels,
            target_sample_rate=target_sample_rate,
            mel_spec_type=mel_spec_type,
        )

    def get_frame_len(self, index):
        row = self.data[index]
        audio = row["audio"]["array"]
        sample_rate = row["audio"]["sampling_rate"]
        return audio.shape[-1] / sample_rate * self.target_sample_rate / self.hop_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data[index]
        audio = row["audio"]["array"]

        # logger.info(f"Audio shape: {audio.shape}")

        sample_rate = row["audio"]["sampling_rate"]
        duration = audio.shape[-1] / sample_rate

        # skip clips outside the 0.3 s to 30 s range, falling through to the next sample
        if duration > 30 or duration < 0.3:
            return self.__getitem__((index + 1) % len(self.data))

        audio_tensor = torch.from_numpy(audio).float()

        if sample_rate != self.target_sample_rate:
            resampler = torchaudio.transforms.Resample(sample_rate, self.target_sample_rate)
            audio_tensor = resampler(audio_tensor)

        audio_tensor = audio_tensor.unsqueeze(0)  # 't -> 1 t'

        mel_spec = self.mel_spectrogram(audio_tensor)

        mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'

        text = row["text"]

        return dict(
            mel_spec=mel_spec,
            text=text,
        )
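

# Note: HFDataset assumes the standard Hugging Face audio schema, inferred from
# the accesses above: row["audio"]["array"] is a 1-D numpy waveform,
# row["audio"]["sampling_rate"] is an int, and row["text"] is the transcript.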


class CustomDataset(Dataset):
    def __init__(
        self,
        custom_dataset: Dataset,
        durations=None,
        target_sample_rate=24_000,
        hop_length=256,
        n_mel_channels=100,
        n_fft=1024,
        win_length=1024,
        mel_spec_type="vocos",
        preprocessed_mel=False,
        mel_spec_module: nn.Module | None = None,
    ):
        self.data = custom_dataset
        self.durations = durations
        self.target_sample_rate = target_sample_rate
        self.hop_length = hop_length
        self.n_fft = n_fft
        self.win_length = win_length
        self.mel_spec_type = mel_spec_type
        self.preprocessed_mel = preprocessed_mel

        if not preprocessed_mel:
            self.mel_spectrogram = default(
                mel_spec_module,
                MelSpec(
                    n_fft=n_fft,
                    hop_length=hop_length,
                    win_length=win_length,
                    n_mel_channels=n_mel_channels,
                    target_sample_rate=target_sample_rate,
                    mel_spec_type=mel_spec_type,
                ),
            )

    def get_frame_len(self, index):
        # make sure any separately provided durations are correct, otherwise an OOM is almost guaranteed
        if self.durations is not None:
            return self.durations[index] * self.target_sample_rate / self.hop_length
        return self.data[index]["duration"] * self.target_sample_rate / self.hop_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        row = self.data[index]
        audio_path = row["audio_path"]
        text = row["text"]
        duration = row["duration"]

        if self.preprocessed_mel:
            mel_spec = torch.tensor(row["mel_spec"])

        else:
            audio, source_sample_rate = torchaudio.load(audio_path)
            if audio.shape[0] > 1:
                audio = torch.mean(audio, dim=0, keepdim=True)  # downmix multi-channel audio to mono

            # skip clips outside the 0.3 s to 30 s range, falling through to the next sample
            if duration > 30 or duration < 0.3:
                return self.__getitem__((index + 1) % len(self.data))

            if source_sample_rate != self.target_sample_rate:
                resampler = torchaudio.transforms.Resample(source_sample_rate, self.target_sample_rate)
                audio = resampler(audio)

            mel_spec = self.mel_spectrogram(audio)
            mel_spec = mel_spec.squeeze(0)  # '1 d t -> d t'

        return dict(
            mel_spec=mel_spec,
            text=text,
        )
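

# Note: CustomDataset expects rows shaped like
#   {"audio_path": "/path/to/clip.wav", "text": "...", "duration": 3.2}
# with duration in seconds, plus a "mel_spec" entry when preprocessed_mel=True
# (field names taken from the accesses above; the example values are made up).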


# Dynamic Batch Sampler


class DynamicBatchSampler(Sampler[list[int]]):
    """Extension of Sampler that will do the following:
    1. Change the batch size (essentially the number of sequences in a batch)
       to ensure that the total number of frames stays below a certain threshold.
    2. Make sure the padding efficiency in the batch is high
       (samples are sorted by length first).
    """

    def __init__(
        self, sampler: Sampler[int], frames_threshold: int, max_samples=0, random_seed=None, drop_last: bool = False
    ):
        self.sampler = sampler
        self.frames_threshold = frames_threshold
        self.max_samples = max_samples

        indices, batches = [], []
        data_source = self.sampler.data_source

        for idx in tqdm(
            self.sampler, desc="Sorting with sampler... if slow, check whether dataset is provided with duration"
        ):
            indices.append((idx, data_source.get_frame_len(idx)))
        indices.sort(key=lambda elem: elem[1])

        batch = []
        batch_frames = 0
        for idx, frame_len in tqdm(
            indices, desc=f"Creating dynamic batches with {frames_threshold} audio frames per gpu"
        ):
            if batch_frames + frame_len <= self.frames_threshold and (max_samples == 0 or len(batch) < max_samples):
                batch.append(idx)
                batch_frames += frame_len
            else:
                if len(batch) > 0:
                    batches.append(batch)
                if frame_len <= self.frames_threshold:
                    batch = [idx]
                    batch_frames = frame_len
                else:
                    # a single sample longer than the threshold is dropped entirely
                    batch = []
                    batch_frames = 0

        if not drop_last and len(batch) > 0:
            batches.append(batch)

        del indices

        # To get different batches between epochs, set a seed and log it in the checkpoint,
        # e.g. (random_seed + n) for epoch n: during multi-gpu training the batch on each gpu
        # does not change between epochs, but the resulting overall minibatch differs.
        random.seed(random_seed)
        random.shuffle(batches)

        self.batches = batches

    def __iter__(self):
        return iter(self.batches)

    def __len__(self):
        return len(self.batches)
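

# Worked example with made-up numbers: given frames_threshold=10 and sorted
# (index, frame_len) pairs [(0, 2), (1, 3), (2, 4), (3, 9)], the greedy pass
# above yields [[0, 1, 2], [3]] -- 2 + 3 + 4 <= 10, then 9 starts a new batch.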


# Load dataset


def load_dataset(
    dataset_name: str,
    tokenizer: str = "pinyin",
    dataset_type: str = "CustomDataset",
    audio_type: str = "raw",
    mel_spec_module: nn.Module | None = None,
    mel_spec_kwargs: dict = dict(),
) -> CustomDataset | HFDataset:
    """
    dataset_type - "CustomDataset" if you want to use the tokenizer name and the default data path to load the train_dataset
                 - "CustomDatasetPath" if you just want to pass the full path to a preprocessed dataset without relying on the tokenizer
    """

    print("Loading dataset ...")

    if dataset_type == "CustomDataset":
        rel_data_path = str(files("f5_tts").joinpath(f"../../data/{dataset_name}_{tokenizer}"))
        if audio_type == "raw":
            try:
                train_dataset = load_from_disk(f"{rel_data_path}/raw")
            except:  # noqa: E722
                train_dataset = Dataset_.from_file(f"{rel_data_path}/raw.arrow")
            preprocessed_mel = False
        elif audio_type == "mel":
            train_dataset = Dataset_.from_file(f"{rel_data_path}/mel.arrow")
            preprocessed_mel = True
        with open(f"{rel_data_path}/duration.json", "r", encoding="utf-8") as f:
            data_dict = json.load(f)
        durations = data_dict["duration"]
        train_dataset = CustomDataset(
            train_dataset,
            durations=durations,
            preprocessed_mel=preprocessed_mel,
            mel_spec_module=mel_spec_module,
            **mel_spec_kwargs,
        )

    elif dataset_type == "CustomDatasetPath":
        try:
            train_dataset = load_from_disk(f"{dataset_name}/raw")
        except:  # noqa: E722
            train_dataset = Dataset_.from_file(f"{dataset_name}/raw.arrow")

        with open(f"{dataset_name}/duration.json", "r", encoding="utf-8") as f:
            data_dict = json.load(f)
        durations = data_dict["duration"]
        # this path loads raw audio, so mels are not precomputed
        train_dataset = CustomDataset(
            train_dataset, durations=durations, preprocessed_mel=False, **mel_spec_kwargs
        )

    elif dataset_type == "HFDataset":
        print(
            "You should manually modify the huggingface dataset path to fit your needs.\n"
            + "You may also need to adapt the corresponding script, since different datasets can have different formats."
        )
        pre, post = dataset_name.split("_")
        train_dataset = HFDataset(
            load_dataset_(f"{pre}/{pre}", split=f"train.{post}", cache_dir=str(files("f5_tts").joinpath("../../data"))),
        )

    return train_dataset
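

# Illustrative call (the dataset name is a placeholder; pass mel settings via mel_spec_kwargs):
#   train_dataset = load_dataset("my_dataset", tokenizer="pinyin", mel_spec_kwargs=dict(mel_spec_type="vocos"))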


# collation


def collate_fn(batch):
    mel_specs = [item["mel_spec"].squeeze(0) for item in batch]
    mel_lengths = torch.LongTensor([spec.shape[-1] for spec in mel_specs])
    max_mel_length = mel_lengths.amax()

    padded_mel_specs = []
    for spec in mel_specs:  # TODO: maybe record a mask for attention here
        padding = (0, max_mel_length - spec.size(-1))
        padded_spec = F.pad(spec, padding, value=0)
        padded_mel_specs.append(padded_spec)

    mel_specs = torch.stack(padded_mel_specs)

    text = [item["text"] for item in batch]
    text_lengths = torch.LongTensor([len(item) for item in text])

    return dict(
        mel=mel_specs,
        mel_lengths=mel_lengths,
        text=text,
        text_lengths=text_lengths,
    )
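
For context, the deleted module was consumed roughly as follows. This is a minimal sketch, assuming the original f5_tts.model.dataset import path and placeholder values for the dataset name and the per-batch frame budget; it is not taken from this diff:

from torch.utils.data import DataLoader, SequentialSampler

from f5_tts.model.dataset import DynamicBatchSampler, collate_fn, load_dataset

train_dataset = load_dataset("my_dataset", tokenizer="pinyin")  # placeholder name

# cap each batch at a fixed frame budget instead of a fixed sample count
batch_sampler = DynamicBatchSampler(
    SequentialSampler(train_dataset), frames_threshold=38400, max_samples=64, random_seed=666
)

train_loader = DataLoader(
    train_dataset,
    batch_sampler=batch_sampler,
    collate_fn=collate_fn,
    num_workers=4,
)

for batch in train_loader:
    mel = batch["mel"]                  # (batch, n_mel_channels, max_t), zero-padded along t
    mel_lengths = batch["mel_lengths"]  # unpadded frame counts, one per sample
    break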