Commit 1d35cf4
Parent(s): e51d4c5
remove columns not in common across datasets

run_speech_recognition_ctc.py (+28 -13)
--- a/run_speech_recognition_ctc.py
+++ b/run_speech_recognition_ctc.py
@@ -331,7 +331,7 @@ def create_vocabulary_from_data(
         batched=True,
         batch_size=10000,
         keep_in_memory=False,
-
+        remove_columns=datasets["train"].column_names,
     )
 
     # take union of all unique characters in each dataset
@@ -418,6 +418,11 @@ def main():
     # 1. First, let's load the dataset
     raw_datasets = DatasetDict()
 
+    def common_cols(dataset_a, dataset_b):
+        col_a = set(dataset_a.column_names)
+        col_b = set(dataset_b.column_names)
+        return [col for col in col_a if col in col_b]
+
     if training_args.do_train:
 
         # Multiple datasets might need to be loaded from HF
@@ -437,18 +442,21 @@ def main():
                         split=train_split_name,
                         use_auth_token=data_args.use_auth_token,
                     )
+                    min_columns_train = raw_datasets["train"].column_names
                 else:
+                    new_dataset = load_dataset(
+                        dataset_name,
+                        dataset_config_name,
+                        split=train_split_name,
+                        use_auth_token=data_args.use_auth_token,
+                    )
                     raw_datasets["train"] = concatenate_datasets(
                         [
                             raw_datasets["train"],
-                            load_dataset(
-                                dataset_name,
-                                dataset_config_name,
-                                split=train_split_name,
-                                use_auth_token=data_args.use_auth_token,
-                            )
+                            new_dataset
                         ]
                     )
+                    min_columns_train = common_cols(min_columns, new_dataset.column_names)
             else:
                 logging.warning(f"{dataset_name} {dataset_config_name} as split is {train_split_name}")
 
@@ -468,6 +476,8 @@ def main():
 
         if data_args.max_train_samples is not None:
             raw_datasets["train"] = raw_datasets["train"].select(range(data_args.max_train_samples))
+        other_columns_train = [col for col in raw_datasets["train"].column_names if col not in min_columns_train]
+        raw_datasets["train"].remove_columns(other_columns_train)
 
     if training_args.do_eval:
         # Multiple datasets might need to be loaded from HF
@@ -486,23 +496,28 @@ def main():
                         split=eval_split_name,
                         use_auth_token=data_args.use_auth_token,
                     )
+                    min_columns_eval = raw_datasets["eval"].column_names
                 else:
+                    new_dataset = load_dataset(
+                        dataset_name,
+                        dataset_config_name,
+                        split=eval_split_name,
+                        use_auth_token=data_args.use_auth_token,
+                    )
                     raw_datasets["eval"] = concatenate_datasets(
                         [
                             raw_datasets["eval"],
-                            load_dataset(
-                                dataset_name,
-                                dataset_config_name,
-                                split=eval_split_name,
-                                use_auth_token=data_args.use_auth_token,
-                            )
+                            new_dataset
                         ]
                     )
+                    min_columns_eval = common_cols(min_columns_eval, new_dataset.column_names)
             else:
                 logging.warning(f"{dataset_name} {dataset_config_name} as split is {eval_split_name}")
 
        if data_args.max_eval_samples is not None:
            raw_datasets["eval"] = raw_datasets["eval"].select(range(data_args.max_eval_samples))
+        other_columns_eval = [col for col in raw_datasets["eval"].column_names if col not in min_columns_eval]
+        raw_datasets["eval"].remove_columns(other_columns_eval)
 
    # 2. We remove some special characters from the datasets
    # that make training complicated and do not help in transcribing the speech
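The change tracks the set of columns shared by every loaded dataset (common_cols, min_columns_train, min_columns_eval) and drops the rest, since datasets pulled from the Hub often expose different metadata columns. Below is a minimal, self-contained sketch of that column-intersection pattern with the datasets library; the toy data, variable names, and two-dataset setup are illustrative only and do not come from the script. Note that Dataset.remove_columns returns a new Dataset rather than modifying in place, so the sketch reassigns its result.

from datasets import Dataset, concatenate_datasets


def common_cols(dataset_a, dataset_b):
    # Columns present in both datasets (same idea as the helper added in this commit).
    cols_b = set(dataset_b.column_names)
    return [col for col in dataset_a.column_names if col in cols_b]


# Toy stand-ins for two datasets that share only some of their columns.
ds_a = Dataset.from_dict({"audio": [1, 2], "sentence": ["a", "b"], "speaker_id": [7, 8]})
ds_b = Dataset.from_dict({"audio": [3], "sentence": ["c"], "accent": ["other"]})

keep = common_cols(ds_a, ds_b)

# Dataset.remove_columns returns a new Dataset, so the result must be reassigned.
ds_a = ds_a.remove_columns([col for col in ds_a.column_names if col not in keep])
ds_b = ds_b.remove_columns([col for col in ds_b.column_names if col not in keep])

# With identical schemas, the pruned datasets concatenate cleanly.
combined = concatenate_datasets([ds_a, ds_b])
print(combined.column_names)  # ['audio', 'sentence']

Pruning each dataset to the shared columns before concatenation keeps the features of all parts identical, which is what concatenate_datasets expects.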