{ "cells": [ { "cell_type": "code", "execution_count": 2, "id": "df6caf53-8d7f-4a5c-a165-679b71e557b3", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/hf_env/lib/python3.8/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", " from .autonotebook import tqdm as notebook_tqdm\n" ] } ], "source": [ "from datasets import Audio, interleave_datasets, IterableDataset, IterableDatasetDict, load_dataset\n", "from transformers import WhisperProcessor\n", "from transformers.models.whisper.english_normalizer import BasicTextNormalizer\n", "from typing import List, Optional" ] }, { "cell_type": "code", "execution_count": 3, "id": "b4fcd9ff-9f1b-4143-89f3-254bba864946", "metadata": {}, "outputs": [], "source": [ "def load_multiple_streaming_datasets(\n", " dataset_names: List,\n", " dataset_config_names: List,\n", " splits: Optional[List] = None,\n", " text_column_names: Optional[List] = None,\n", " sampling_rate: Optional[int] = 16000,\n", " stopping_strategy: Optional[str] = \"all_exhausted\",\n", " **kwargs\n", ") -> IterableDataset:\n", "\n", " if len(dataset_names) != len(dataset_config_names):\n", " raise ValueError(\n", " f\"Ensure one config is passed for each dataset, got {len(dataset_names)} datasets and\"\n", " f\" {len(dataset_config_names)} configs.\"\n", " )\n", "\n", " if splits is not None and len(splits) != len(dataset_names):\n", " raise ValueError(\n", " f\"Ensure one split is passed for each dataset, got {len(dataset_names)} datasets and {len(splits)} splits.\"\n", " )\n", "\n", " if text_column_names is not None and len(text_column_names) != len(dataset_names):\n", " raise ValueError(\n", " f\"Ensure one text column name is passed for each dataset, got {len(dataset_names)} datasets and\"\n", " f\" {len(text_column_names)} text column names.\"\n", " )\n", "\n", " splits = splits if splits is not None else [\"train\" for i in range(len(dataset_names))]\n", " text_column_names = (\n", " text_column_names if text_column_names is not None else [\"text\" for i in range(len(dataset_names))]\n", " )\n", "\n", " all_datasets = []\n", " # iterate over the datasets we want to interleave\n", " for i, dataset_name in enumerate(dataset_names):\n", " dataset = load_dataset(dataset_name, dataset_config_names[i], split=splits[i], streaming=True, **kwargs)\n", " # resample to specified sampling rate\n", " dataset = dataset.cast_column(\"audio\", Audio(sampling_rate))\n", " # normalise columns to [\"audio\", \"sentence\"]\n", " if text_column_names[i] != \"sentence\":\n", " dataset = dataset.rename_column(text_column_names[i], \"sentence\")\n", " dataset = dataset.remove_columns(set(dataset.features.keys()) - set([\"audio\", \"sentence\"]))\n", " all_datasets.append(dataset)\n", "\n", " interleaved_dataset = interleave_datasets(all_datasets, stopping_strategy=stopping_strategy)\n", " return interleaved_dataset" ] }, { "cell_type": "code", "execution_count": 4, "id": "89fc237e-3f44-4330-8b7e-5b8c350d83d9", "metadata": {}, "outputs": [], "source": [ "dataset_names = [\"mozilla-foundation/common_voice_11_0\", \"google/fleurs\", \"openslr\", \"collectivat/tv3_parla\", \"projecte-aina/parlament_parla\", \"projecte-aina/parlament_parla\"]\n", "dataset_config_names = [\"ca\", \"ca_es\", \"SLR69\", \"ca\", \"clean\", \"other\"]\n", "text_column_names = [\"sentence\", \"transcription\", \"sentence\", \"text\", \"sentence\", \"sentence\"]" ] 
}, { "cell_type": "code", "execution_count": 5, "id": "c2cf8afb-5fb5-4e81-9d73-4f92ca73d395", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading builder script: 100%|██████████| 8.30k/8.30k [00:00<00:00, 9.77MB/s]\n", "Downloading readme: 100%|██████████| 12.2k/12.2k [00:00<00:00, 15.9MB/s]\n", "Downloading extra modules: 100%|██████████| 3.44k/3.44k [00:00<00:00, 2.42MB/s]\n", "Downloading extra modules: 100%|██████████| 60.9k/60.9k [00:00<00:00, 561kB/s]\n", "Downloading builder script: 100%|██████████| 12.8k/12.8k [00:00<00:00, 6.66MB/s]\n", "Downloading readme: 100%|██████████| 11.2k/11.2k [00:00<00:00, 10.0MB/s]\n", "Downloading builder script: 100%|██████████| 26.9k/26.9k [00:00<00:00, 502kB/s]\n", "Downloading metadata: 100%|██████████| 210k/210k [00:00<00:00, 967kB/s] \n", "Downloading readme: 100%|██████████| 42.9k/42.9k [00:00<00:00, 395kB/s]\n", "Downloading builder script: 100%|██████████| 3.98k/3.98k [00:00<00:00, 6.60MB/s]\n", "Downloading readme: 100%|██████████| 5.15k/5.15k [00:00<00:00, 8.64MB/s]\n", "Using custom data configuration ca\n", "Downloading builder script: 100%|██████████| 5.13k/5.13k [00:00<00:00, 8.56MB/s]\n", "Downloading readme: 100%|██████████| 8.64k/8.64k [00:00<00:00, 10.2MB/s]\n" ] } ], "source": [ "trainset = load_multiple_streaming_datasets(dataset_names, dataset_config_names=dataset_config_names, text_column_names=text_column_names, use_auth_token=True)" ] }, { "cell_type": "code", "execution_count": 6, "id": "70d92e1d-bf5b-4fa3-ae15-15ae1e8aba49", "metadata": {}, "outputs": [], "source": [ "testset = IterableDataset\n", "testset = load_dataset(\"mozilla-foundation/common_voice_11_0\", \"ca\", split=\"test\", streaming=True, use_auth_token=True)\n", "testset = testset.cast_column(\"audio\", Audio(sampling_rate=16000))" ] }, { "cell_type": "code", "execution_count": 7, "id": "83fdafea-8850-446a-96ec-334f455511c4", "metadata": {}, "outputs": [], "source": [ "COLUMNS_TO_KEEP = [\"sentence\", \"audio\"]\n", "all_columns = testset.features\n", "columns_to_remove = set(all_columns) - set(COLUMNS_TO_KEEP)\n", "\n", "testset = testset.remove_columns(columns_to_remove)" ] }, { "cell_type": "code", "execution_count": 8, "id": "eec854ee-a38b-45f7-9404-2425ec858ca7", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),\n", " 'sentence': Value(dtype='string', id=None)}" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trainset.features" ] }, { "cell_type": "code", "execution_count": 9, "id": "b0a4ac05-1245-4d4a-9186-34e286a452e0", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "{'audio': Audio(sampling_rate=16000, mono=True, decode=True, id=None),\n", " 'sentence': Value(dtype='string', id=None)}" ] }, "execution_count": 9, "metadata": {}, "output_type": "execute_result" } ], "source": [ "testset.features" ] }, { "cell_type": "code", "execution_count": 10, "id": "213d7970-65fc-4c67-b34f-386927061325", "metadata": {}, "outputs": [], "source": [ "do_lower_case = True\n", "do_remove_punctuation = True\n", "\n", "normalizer = BasicTextNormalizer()" ] }, { "cell_type": "code", "execution_count": 11, "id": "5d28d851-7a0c-4e5d-a346-29932ab1e16d", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading: 100%|██████████| 185k/185k [00:00<00:00, 1.68MB/s]\n", "Downloading: 100%|██████████| 830/830 [00:00<00:00, 1.56MB/s]\n", "Downloading: 100%|██████████| 
1.04M/1.04M [00:00<00:00, 3.79MB/s]\n", "Downloading: 100%|██████████| 494k/494k [00:00<00:00, 1.82MB/s]\n", "Downloading: 100%|██████████| 52.7k/52.7k [00:00<00:00, 485kB/s]\n", "Downloading: 100%|██████████| 2.11k/2.11k [00:00<00:00, 4.12MB/s]\n", "Downloading: 100%|██████████| 2.06k/2.06k [00:00<00:00, 3.79MB/s]\n" ] } ], "source": [ "processor = WhisperProcessor.from_pretrained(\"openai/whisper-medium\", language=\"Catalan\", task=\"transcribe\")" ] }, { "cell_type": "code", "execution_count": 12, "id": "ca5898c1-390a-4f3b-bf69-00691b0a078f", "metadata": {}, "outputs": [], "source": [ "def prepare_dataset(batch):\n", " # load and (possibly) resample audio data to 16kHz\n", " audio = batch[\"audio\"]\n", "\n", " # compute log-Mel input features from input audio array \n", " batch[\"input_features\"] = processor.feature_extractor(audio[\"array\"], sampling_rate=audio[\"sampling_rate\"]).input_features[0]\n", " # compute input length of audio sample in seconds\n", " batch[\"input_length\"] = len(audio[\"array\"]) / audio[\"sampling_rate\"]\n", " \n", " # optional pre-processing steps\n", " transcription = batch[\"sentence\"]\n", " if do_lower_case:\n", " transcription = transcription.lower()\n", " if do_remove_punctuation:\n", " transcription = normalizer(transcription).strip()\n", " \n", " # encode target text to label ids\n", " batch[\"labels\"] = processor.tokenizer(transcription).input_ids\n", " return batch" ] }, { "cell_type": "code", "execution_count": 13, "id": "70b4d4b2-e6e3-4c37-9607-a47cf27abf99", "metadata": {}, "outputs": [], "source": [ "vectorized_trainset = trainset.map(prepare_dataset).with_format(\"torch\")\n", "vectorized_testset = testset.map(prepare_dataset).with_format(\"torch\")" ] }, { "cell_type": "code", "execution_count": 14, "id": "42873264-27d0-43db-b0ac-5b40b915dee6", "metadata": {}, "outputs": [], "source": [ "vectorized_trainset = vectorized_trainset.shuffle( buffer_size=500,seed=0,)\n", "vectorized_testset = vectorized_testset.shuffle( buffer_size=500,seed=0,)" ] }, { "cell_type": "code", "execution_count": 15, "id": "0a64bcf6-3262-49b4-bbfb-d5b5ed59b395", "metadata": {}, "outputs": [], "source": [ "MAX_DURATION_IN_SECONDS = 30.0\n", "\n", "def is_audio_length_in_range(input_length):\n", " return input_length < MAX_DURATION_IN_SECONDS" ] }, { "cell_type": "code", "execution_count": 16, "id": "bfa12b9e-a8ac-42fa-8009-4105b4f1d346", "metadata": {}, "outputs": [], "source": [ "vectorized_trainset = vectorized_trainset.filter(is_audio_length_in_range, input_columns=[\"input_length\"])\n", "vectorized_testset = vectorized_testset.filter(is_audio_length_in_range, input_columns=[\"input_length\"])" ] }, { "cell_type": "code", "execution_count": 17, "id": "fd54b51d-b2e3-4e09-97db-40907126f178", "metadata": {}, "outputs": [], "source": [ "import torch\n", "\n", "from dataclasses import dataclass\n", "from typing import Any, Dict, List, Union\n", "\n", "@dataclass\n", "class DataCollatorSpeechSeq2SeqWithPadding:\n", " processor: Any\n", "\n", " def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:\n", " # split inputs and labels since they have to be of different lengths and need different padding methods\n", " # first treat the audio inputs by simply returning torch tensors\n", " input_features = [{\"input_features\": feature[\"input_features\"]} for feature in features]\n", " batch = self.processor.feature_extractor.pad(input_features, return_tensors=\"pt\")\n", "\n", " # get the tokenized label sequences\n", " 
label_features = [{\"input_ids\": feature[\"labels\"]} for feature in features]\n", " # pad the labels to max length\n", " labels_batch = self.processor.tokenizer.pad(label_features, return_tensors=\"pt\")\n", "\n", " # replace padding with -100 to ignore loss correctly\n", " labels = labels_batch[\"input_ids\"].masked_fill(labels_batch.attention_mask.ne(1), -100)\n", "\n", " # if bos token is appended in previous tokenization step,\n", " # cut bos token here as it's append later anyways\n", " if (labels[:, 0] == self.processor.tokenizer.bos_token_id).all().cpu().item():\n", " labels = labels[:, 1:]\n", "\n", " batch[\"labels\"] = labels\n", "\n", " return batch" ] }, { "cell_type": "code", "execution_count": 18, "id": "17691122-25f0-4302-ada8-0f01dc1a29d9", "metadata": {}, "outputs": [], "source": [ "data_collator = DataCollatorSpeechSeq2SeqWithPadding(processor=processor)" ] }, { "cell_type": "code", "execution_count": 19, "id": "66d06025-fe27-4452-9d53-a8afedef3cbf", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading builder script: 100%|██████████| 4.49k/4.49k [00:00<00:00, 7.30MB/s]\n" ] } ], "source": [ "import evaluate\n", "\n", "metric = evaluate.load(\"wer\")" ] }, { "cell_type": "code", "execution_count": 20, "id": "f0607b5e-3c18-4347-b59d-d810d7d2b572", "metadata": {}, "outputs": [], "source": [ "# evaluate with the 'normalised' WER\n", "do_normalize_eval = True\n", "\n", "def compute_metrics(pred):\n", " pred_ids = pred.predictions\n", " label_ids = pred.label_ids\n", "\n", " # replace -100 with the pad_token_id\n", " label_ids[label_ids == -100] = processor.tokenizer.pad_token_id\n", "\n", " # we do not want to group tokens when computing the metrics\n", " pred_str = processor.tokenizer.batch_decode(pred_ids, skip_special_tokens=True)\n", " label_str = processor.tokenizer.batch_decode(label_ids, skip_special_tokens=True)\n", "\n", " if do_normalize_eval:\n", " pred_str = [normalizer(pred) for pred in pred_str]\n", " label_str = [normalizer(label) for label in label_str]\n", "\n", " wer = 100 * metric.compute(predictions=pred_str, references=label_str)\n", "\n", " return {\"wer\": wer}" ] }, { "cell_type": "code", "execution_count": 21, "id": "f90b023f-0988-4c94-aea9-68f3d59dce47", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Downloading: 100%|██████████| 1.97k/1.97k [00:00<00:00, 3.60MB/s]\n", "Downloading: 100%|██████████| 3.06G/3.06G [00:35<00:00, 85.2MB/s]\n" ] } ], "source": [ "from transformers import WhisperForConditionalGeneration\n", "\n", "model = WhisperForConditionalGeneration.from_pretrained(\"openai/whisper-medium\")" ] }, { "cell_type": "code", "execution_count": 26, "id": "6fd383ff-9911-4fab-8d3f-9711baa38993", "metadata": {}, "outputs": [], "source": [ "model.config.forced_decoder_ids = None\n", "model.config.suppress_tokens = []\n", "model.config.use_cache = False" ] }, { "cell_type": "code", "execution_count": 31, "id": "0d960bd6-fa06-4d5c-b322-b754edc1e69b", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "PyTorch: setting up devices\n" ] } ], "source": [ "from transformers import Seq2SeqTrainingArguments\n", "\n", "training_args = Seq2SeqTrainingArguments(\n", " output_dir=\"./\",\n", " per_device_train_batch_size=32,\n", " gradient_accumulation_steps=2, # increase by 2x for every 2x decrease in batch size\n", " learning_rate=1e-5,\n", " warmup_steps=1000,\n", " max_steps=10000,\n", " gradient_checkpointing=True,\n", " fp16=True,\n", " 
evaluation_strategy=\"steps\",\n", " per_device_eval_batch_size=8,\n", " predict_with_generate=True,\n", " generation_max_length=225,\n", " save_steps=1000,\n", " eval_steps=1000,\n", " logging_steps=25,\n", " report_to=[\"tensorboard\"],\n", " load_best_model_at_end=True,\n", " metric_for_best_model=\"wer\",\n", " greater_is_better=False,\n", " push_to_hub=True,\n", ")" ] }, { "cell_type": "code", "execution_count": 32, "id": "47b13715-e6d6-48b5-8296-58ec999e476a", "metadata": {}, "outputs": [], "source": [ "from transformers import TrainerCallback\n", "from transformers.trainer_pt_utils import IterableDatasetShard\n", "from torch.utils.data import IterableDataset\n", "\n", "# trainer callback to reinitialise and reshuffle the streamable datasets at the beginning of each epoch\n", "class ShuffleCallback(TrainerCallback):\n", " def on_epoch_begin(self, args, state, control, train_dataloader, **kwargs):\n", " if isinstance(train_dataloader.dataset, IterableDatasetShard):\n", " pass # set_epoch() is handled by the Trainer\n", " elif isinstance(train_dataloader.dataset, IterableDataset):\n", " train_dataloader.dataset.set_epoch(train_dataloader.dataset._epoch + 1)" ] }, { "cell_type": "code", "execution_count": 33, "id": "74a19d5b-dde9-448c-a0c3-ff289816ee5f", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/whisper-medium-ca/./ is already a clone of https://huggingface.co/JulioCastro/whisper-medium-ca. Make sure you pull the latest changes with `repo.git_pull()`.\n", "max_steps is given, it will override any value given in num_train_epochs\n", "Using cuda_amp half precision backend\n" ] } ], "source": [ "from transformers import Seq2SeqTrainer\n", "\n", "trainer = Seq2SeqTrainer(\n", " args=training_args,\n", " model=model,\n", " train_dataset=vectorized_trainset,\n", " eval_dataset=vectorized_testset,\n", " data_collator=data_collator,\n", " compute_metrics=compute_metrics,\n", " tokenizer=processor,\n", " callbacks=[ShuffleCallback()],\n", ")" ] }, { "cell_type": "code", "execution_count": 34, "id": "95e3ad06-96cc-4c1a-a45c-765796dbfcdc", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "Configuration saved in ./config.json\n", "Model weights saved in ./pytorch_model.bin\n", "Feature extractor saved in ./preprocessor_config.json\n", "tokenizer config file saved in ./tokenizer_config.json\n", "Special tokens file saved in ./special_tokens_map.json\n", "added tokens file saved in ./added_tokens.json\n" ] } ], "source": [ "model.save_pretrained(training_args.output_dir)\n", "processor.save_pretrained(training_args.output_dir)" ] }, { "cell_type": "code", "execution_count": 35, "id": "e4ce26c9-4737-4565-96b1-6da0d99ce084", "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/home/ubuntu/hf_env/lib/python3.8/site-packages/transformers/optimization.py:306: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n", " warnings.warn(\n", "***** Running training *****\n", " Num examples = 640000\n", " Num Epochs = 9223372036854775807\n", " Instantaneous batch size per device = 32\n", " Total train batch size (w. 
parallel, distributed & accumulation) = 64\n", " Gradient Accumulation steps = 2\n", " Total optimization steps = 10000\n", " Number of trainable parameters = 763857920\n", "Reading metadata...: 905243it [00:14, 60880.41it/s]\n", "The following columns in the training set don't have a corresponding argument in `WhisperForConditionalGeneration.forward` and have been ignored: audio, input_length, sentence. If audio, input_length, sentence are not expected by `WhisperForConditionalGeneration.forward`, you can safely ignore this message.\n" ] }, { "data": { "text/html": [ "\n", "
<table>\n", "  <tr><th>Step</th><th>Training Loss</th><th>Validation Loss</th><th>Wer</th></tr>\n", "  <tr><td>1000</td><td>0.135000</td><td>0.226110</td><td>12.893873</td></tr>\n", "  <tr><td>2000</td><td>0.103200</td><td>0.190505</td><td>10.003139</td></tr>\n", "</table>
"
],
"text/plain": [
"