Commit
·
fcff61b
1
Parent(s):
eec3f65
Add bandaid for empty strings
Browse files
run_speech_recognition_seq2seq.py
CHANGED
@@ -46,7 +46,6 @@ from transformers.trainer_utils import get_last_checkpoint, is_main_process
|
|
46 |
from transformers.utils import check_min_version
|
47 |
from transformers.utils.versions import require_version
|
48 |
|
49 |
-
|
50 |
# Will error if the minimum required version of Transformers is not installed. Remove at your own risk.
|
51 |
check_min_version("4.17.0.dev0")
|
52 |
|
@@ -89,7 +88,7 @@ class ModelArguments:
|
|
89 |
default=False,
|
90 |
metadata={
|
91 |
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
92 |
-
|
93 |
},
|
94 |
)
|
95 |
freeze_feature_encoder: bool = field(
|
@@ -124,14 +123,14 @@ class DataTrainingArguments:
|
|
124 |
default=None,
|
125 |
metadata={
|
126 |
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
127 |
-
|
128 |
},
|
129 |
)
|
130 |
max_eval_samples: Optional[int] = field(
|
131 |
default=None,
|
132 |
metadata={
|
133 |
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
134 |
-
|
135 |
},
|
136 |
)
|
137 |
audio_column_name: str = field(
|
@@ -155,9 +154,9 @@ class DataTrainingArguments:
|
|
155 |
default=False,
|
156 |
metadata={
|
157 |
"help": "Whether to only do data preprocessing and skip training. "
|
158 |
-
|
159 |
-
|
160 |
-
|
161 |
},
|
162 |
)
|
163 |
train_split_name: str = field(
|
@@ -283,12 +282,14 @@ def main():
|
|
283 |
|
284 |
if training_args.do_train:
|
285 |
raw_datasets["train"] = load_dataset(
|
286 |
-
data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name,
|
|
|
287 |
)
|
288 |
|
289 |
if training_args.do_eval:
|
290 |
raw_datasets["eval"] = load_dataset(
|
291 |
-
data_args.dataset_name, data_args.dataset_config_name, split=data_args.eval_split_name,
|
|
|
292 |
)
|
293 |
|
294 |
if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
|
@@ -378,6 +379,8 @@ def main():
|
|
378 |
input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
|
379 |
|
380 |
input_str = re.sub(r"<\*?(ee|qq|mm|inaudible)>", "", input_str, re.IGNORECASE)
|
|
|
|
|
381 |
|
382 |
batch["labels"] = tokenizer(input_str).input_ids
|
383 |
return batch
|
|
|
46 |
from transformers.utils import check_min_version
|
47 |
from transformers.utils.versions import require_version
|
48 |
|
|
|
49 |
# Will error if the minimum required version of Transformers is not installed. Remove at your own risk.
|
50 |
check_min_version("4.17.0.dev0")
|
51 |
|
|
|
88 |
default=False,
|
89 |
metadata={
|
90 |
"help": "Will use the token generated when running `transformers-cli login` (necessary to use this script "
|
91 |
+
"with private models)."
|
92 |
},
|
93 |
)
|
94 |
freeze_feature_encoder: bool = field(
|
|
|
123 |
default=None,
|
124 |
metadata={
|
125 |
"help": "For debugging purposes or quicker training, truncate the number of training examples to this "
|
126 |
+
"value if set."
|
127 |
},
|
128 |
)
|
129 |
max_eval_samples: Optional[int] = field(
|
130 |
default=None,
|
131 |
metadata={
|
132 |
"help": "For debugging purposes or quicker training, truncate the number of evaluation examples to this "
|
133 |
+
"value if set."
|
134 |
},
|
135 |
)
|
136 |
audio_column_name: str = field(
|
|
|
154 |
default=False,
|
155 |
metadata={
|
156 |
"help": "Whether to only do data preprocessing and skip training. "
|
157 |
+
"This is especially useful when data preprocessing errors out in distributed training due to timeout. "
|
158 |
+
"In this case, one should run the preprocessing in a non-distributed setup with `preprocessing_only=True` "
|
159 |
+
"so that the cached datasets can consequently be loaded in distributed training"
|
160 |
},
|
161 |
)
|
162 |
train_split_name: str = field(
|
|
|
282 |
|
283 |
if training_args.do_train:
|
284 |
raw_datasets["train"] = load_dataset(
|
285 |
+
data_args.dataset_name, data_args.dataset_config_name, split=data_args.train_split_name,
|
286 |
+
cache_dir=data_args.data_cache_dir
|
287 |
)
|
288 |
|
289 |
if training_args.do_eval:
|
290 |
raw_datasets["eval"] = load_dataset(
|
291 |
+
data_args.dataset_name, data_args.dataset_config_name, split=data_args.eval_split_name,
|
292 |
+
cache_dir=data_args.data_cache_dir
|
293 |
)
|
294 |
|
295 |
if data_args.audio_column_name not in next(iter(raw_datasets.values())).column_names:
|
|
|
379 |
input_str = batch[text_column_name].lower() if do_lower_case else batch[text_column_name]
|
380 |
|
381 |
input_str = re.sub(r"<\*?(ee|qq|mm|inaudible)>", "", input_str, re.IGNORECASE)
|
382 |
+
if len(input_str) == 0:
|
383 |
+
input_str = "." # bandaid
|
384 |
|
385 |
batch["labels"] = tokenizer(input_str).input_ids
|
386 |
return batch
|