from typing import Dict, Any
import torch
import librosa
import numpy as np
from datasets import Dataset
from ..cloning.model import CloningModel
from ..transcriber.model import TranscriberModel

def prepare_dataset(example: Dict[str, Any], model: CloningModel) -> Dict[str, Any]:
    """
    Prepare a single example for training.
    """
    # feature extraction and tokenization
    processed_example = model.processor(
        text=example["normalized_text"],
        audio_target=example["audio"]["array"],
        sampling_rate=16000,
        return_attention_mask=False,
    )

    # strip off the batch dimension if the processor returned one
    if len(torch.tensor(processed_example["input_ids"]).shape) > 1:
        processed_example["input_ids"] = processed_example["input_ids"][0]
        processed_example["labels"] = processed_example["labels"][0]

    # use SpeechBrain to obtain an x-vector speaker embedding
    processed_example["speaker_embeddings"] = model.create_speaker_embedding(
        torch.tensor(example["audio"]["array"])
    ).numpy()

    return processed_example
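
# For reference, a minimal sketch of what `create_speaker_embedding` might do,
# assuming CloningModel wraps SpeechBrain's EncoderClassifier with the
# speechbrain/spkrec-xvect-voxceleb checkpoint (the pattern used in the
# Hugging Face SpeechT5 fine-tuning guide). The actual implementation lives
# in ..cloning.model and may differ:
#
#     from speechbrain.pretrained import EncoderClassifier
#
#     speaker_model = EncoderClassifier.from_hparams(
#         source="speechbrain/spkrec-xvect-voxceleb")
#
#     def create_speaker_embedding(waveform: torch.Tensor) -> torch.Tensor:
#         with torch.no_grad():
#             embeddings = speaker_model.encode_batch(waveform)
#             embeddings = torch.nn.functional.normalize(embeddings, dim=2)
#         return embeddings.squeeze()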

def get_cloning_dataset(input_audio_path: str,
                        transcriber_model: TranscriberModel,
                        cloning_model: CloningModel,
                        sampling_rate: int = 16000,
                        window_size_secs: int = 5) -> Dataset:
    """
    Create a dataset by transcribing an audio file with a pretrained Wav2Vec2 model.
    """
    speech_array, _ = librosa.load(input_audio_path, sr=sampling_rate)

    # split the waveform into windows of `window_size_secs` seconds each;
    # np.split at index 0 yields an empty leading chunk, which [1:] drops
    speech_arrays = np.split(
        speech_array, range(0, len(speech_array), window_size_secs * sampling_rate)
    )[1:]

    # transcribe each window independently
    texts = [transcriber_model.forward(speech_array, sampling_rate=sampling_rate)
             for speech_array in speech_arrays]

    dataset = Dataset.from_list([
        {"audio": {"array": speech_arrays[i]}, "normalized_text": texts[i]}
        for i in range(len(speech_arrays))
    ])

    # tokenize, extract features, and attach speaker embeddings per example
    dataset = dataset.map(
        prepare_dataset, fn_kwargs={"model": cloning_model},
        remove_columns=dataset.column_names,
    )

    return dataset
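
# Usage sketch. Assumes TranscriberModel and CloningModel can be constructed
# without arguments; their real constructors live elsewhere in this repo and
# may take checkpoint or device arguments. "reference_voice.wav" is a
# hypothetical input path. Run as a module (python -m ...) so the relative
# imports above resolve.
if __name__ == "__main__":
    transcriber = TranscriberModel()
    cloner = CloningModel()
    ds = get_cloning_dataset("reference_voice.wav", transcriber, cloner)
    # typical resulting columns: input_ids, labels, speaker_embeddings
    print(ds)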