Joshua Lochner committed
Commit cfbd4d5 · Parent(s): de9c8c4
Update preprocessing script to use logging module
Files changed: src/preprocess.py (+28 -27)
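The commit introduces a module-level logger and routes the script's status messages through it. The diff does not show where logging is configured, so the following is only a minimal sketch of the standard-library pattern involved; the basicConfig format and level are assumptions, not part of this commit.

import logging

# Module-level logger, matching the line added in the first hunk below.
logger = logging.getLogger(__name__)

if __name__ == '__main__':
    # Assumed setup: basicConfig is the usual minimal configuration for a
    # script; the actual configuration is not shown in this diff.
    logging.basicConfig(
        format='%(asctime)s %(levelname)s %(name)s: %(message)s',
        level=logging.INFO,
    )
    logger.info('Updating database')  # style of the messages added below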
src/preprocess.py CHANGED

@@ -20,6 +20,9 @@ import time
 import requests


+logger = logging.getLogger(__name__)
+
+
 PROFANITY_RAW = '[ __ ]'  # How YouTube transcribes profanity
 PROFANITY_CONVERTED = '*****'  # Safer version for tokenizing

@@ -204,7 +207,7 @@ def get_words(video_id, process=True, transcript_type='auto', fallback='manual',
             pass  # Mark as empty transcript

     except json.decoder.JSONDecodeError:
-
+        logger.warning(f'JSONDecodeError for {video_id}')
         if os.path.exists(transcript_path):
             os.remove(transcript_path)  # Remove file and try again
         return get_words(video_id, process, transcript_type, fallback, granularity)
@@ -543,12 +546,12 @@ def main():
         preprocess_args.raw_data_dir, preprocess_args.raw_data_file)

     if preprocess_args.update_database:
-
+        logger.info('Updating database')
         for mirror in MIRRORS:
-
+            logger.info(f'Downloading from {mirror}')
             if download_file(mirror, raw_dataset_path):
                 break
-
+            logger.warning('Failed, trying next')

     os.makedirs(dataset_args.data_dir, exist_ok=True)
     processed_db_path = os.path.join(
@@ -558,11 +561,10 @@ def main():
     @lru_cache(maxsize=1)
     def read_db():
         if not preprocess_args.overwrite and os.path.exists(processed_db_path):
-
-                'Using cached processed database (use `--overwrite` to avoid this behaviour).')
+            logger.info('Using cached processed database (use `--overwrite` to avoid this behaviour).')
             with open(processed_db_path) as fp:
                 return json.load(fp)
-
+        logger.info('Processing raw database')
         db = {}

         allowed_categories = list(map(str.lower, CATGEGORY_OPTIONS))
@@ -618,7 +620,7 @@ def main():

         # Remove duplicate sponsor segments by choosing best (most votes)
         if not preprocess_args.keep_duplicate_segments:
-
+            logger.info('Remove duplicate segments')
             for key in db:
                 db[key] = remove_duplicate_segments(db[key])

@@ -646,7 +648,7 @@ def main():

         # TODO remove videos that contain a full-video label?

-
+        logger.info(f'Saved {len(db)} videos')

         with open(processed_db_path, 'w') as fp:
             json.dump(db, fp)
@@ -660,7 +662,7 @@ def main():
     # 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
     # 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'
     if preprocess_args.do_transcribe:
-
+        logger.info('Collecting videos')
         parsed_database = read_db()

         # Remove transcripts already processed
@@ -678,7 +680,7 @@ def main():
             get_words(video_id)
             return video_id

-
+        logger.info('Setting up ThreadPoolExecutor')
         with concurrent.futures.ThreadPoolExecutor(max_workers=preprocess_args.num_jobs) as pool, \
                 tqdm(total=len(video_ids)) as progress:

@@ -698,21 +700,21 @@ def main():
                     progress.update()

             except KeyboardInterrupt:
-
+                logger.info('Gracefully shutting down: Cancelling unscheduled tasks')

                 # only futures that are not done will prevent exiting
                 for future in to_process:
                     future.cancel()

-
+                logger.info('Waiting for in-progress tasks to complete')
                 concurrent.futures.wait(to_process, timeout=None)
-
+                logger.info('Cancellation successful')

     final_path = os.path.join(
         dataset_args.data_dir, dataset_args.processed_file)

     if preprocess_args.do_create:
-
+        logger.info('Create final data')

         final_data = {}

@@ -786,7 +788,7 @@ def main():
         dataset_args.data_dir, dataset_args.negative_file)

     if preprocess_args.do_generate:
-
+        logger.info('Generating')
         # max_videos=preprocess_args.max_videos,
         # max_segments=preprocess_args.max_segments,
         # , max_videos, max_segments
@@ -868,8 +870,8 @@ def main():
                     print(json.dumps(d), file=negative)

     if preprocess_args.do_split:
-
-
+        logger.info('Splitting')
+        logger.info('Read files')

         with open(positive_file, encoding='utf-8') as positive:
             sponsors = positive.readlines()
@@ -877,11 +879,11 @@ def main():
         with open(negative_file, encoding='utf-8') as negative:
             non_sponsors = negative.readlines()

-
+        logger.info('Shuffle')
         random.shuffle(sponsors)
         random.shuffle(non_sponsors)

-
+        logger.info('Calculate ratios')
         # Ensure correct ratio of positive to negative segments
         percentage_negative = 1 - preprocess_args.percentage_positive

@@ -901,12 +903,12 @@ def main():
         excess = non_sponsors[z:]
         non_sponsors = non_sponsors[:z]

-
+        logger.info('Join')
         all_labelled_segments = sponsors + non_sponsors

         random.shuffle(all_labelled_segments)

-
+        logger.info('Split')
         ratios = [preprocess_args.train_split,
                   preprocess_args.test_split,
                   preprocess_args.valid_split]
@@ -927,9 +929,9 @@ def main():
                 with open(outfile, 'w', encoding='utf-8') as fp:
                     fp.writelines(items)
             else:
-
+                logger.info(f'Skipping {name}')

-
+        logger.info('Write')
         # Save excess items
         excess_path = os.path.join(
             dataset_args.data_dir, dataset_args.excess_file)
@@ -937,10 +939,9 @@ def main():
             with open(excess_path, 'w', encoding='utf-8') as fp:
                 fp.writelines(excess)
         else:
-
+            logger.info(f'Skipping {dataset_args.excess_file}')

-
-              'sponsors,', len(non_sponsors), 'non sponsors')
+        logger.info(f'Finished splitting: {len(sponsors)} sponsors, {len(non_sponsors)} non sponsors')


 def split(arr, ratios):
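The messages added around the KeyboardInterrupt handler describe the usual cancel-then-wait shutdown for a thread pool. A minimal, self-contained sketch of that flow (illustrative only; work() and its sleep are placeholders, not code from preprocess.py):

import concurrent.futures
import time

def work(i):
    time.sleep(1)  # placeholder task
    return i

with concurrent.futures.ThreadPoolExecutor(max_workers=4) as pool:
    to_process = [pool.submit(work, i) for i in range(100)]
    try:
        for future in concurrent.futures.as_completed(to_process):
            future.result()
    except KeyboardInterrupt:
        # Cancel futures that have not started yet ...
        for future in to_process:
            future.cancel()
        # ... then wait for the ones already running before exiting.
        concurrent.futures.wait(to_process, timeout=None)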