Joshua Lochner committed
Commit cfbd4d5 · 1 Parent(s): de9c8c4

Update preprocessing script to use logging module

Files changed (1):
  src/preprocess.py  +28 -27
src/preprocess.py CHANGED

@@ -20,6 +20,9 @@ import time
 import requests
 
 
+logger = logging.getLogger(__name__)
+
+
 PROFANITY_RAW = '[ __ ]' # How YouTube transcribes profanity
 PROFANITY_CONVERTED = '*****' # Safer version for tokenizing
 
@@ -204,7 +207,7 @@ def get_words(video_id, process=True, transcript_type='auto', fallback='manual',
             pass # Mark as empty transcript
 
     except json.decoder.JSONDecodeError:
-        print('JSONDecodeError for', video_id)
+        logger.warning(f'JSONDecodeError for {video_id}')
         if os.path.exists(transcript_path):
             os.remove(transcript_path) # Remove file and try again
         return get_words(video_id, process, transcript_type, fallback, granularity)
@@ -543,12 +546,12 @@ def main():
         preprocess_args.raw_data_dir, preprocess_args.raw_data_file)
 
     if preprocess_args.update_database:
-        print('Updating database')
+        logger.info('Updating database')
         for mirror in MIRRORS:
-            print('Downloading from', mirror)
+            logger.info(f'Downloading from {mirror}')
             if download_file(mirror, raw_dataset_path):
                 break
-            print('Failed, trying next')
+            logger.warning('Failed, trying next')
 
     os.makedirs(dataset_args.data_dir, exist_ok=True)
     processed_db_path = os.path.join(
@@ -558,11 +561,10 @@ def main():
     @lru_cache(maxsize=1)
     def read_db():
         if not preprocess_args.overwrite and os.path.exists(processed_db_path):
-            print(
-                'Using cached processed database (use `--overwrite` to avoid this behaviour).')
+            logger.info('Using cached processed database (use `--overwrite` to avoid this behaviour).')
             with open(processed_db_path) as fp:
                 return json.load(fp)
-        print('Processing raw database')
+        logger.info('Processing raw database')
         db = {}
 
         allowed_categories = list(map(str.lower, CATGEGORY_OPTIONS))
@@ -618,7 +620,7 @@ def main():
 
         # Remove duplicate sponsor segments by choosing best (most votes)
         if not preprocess_args.keep_duplicate_segments:
-            print('Remove duplicate segments')
+            logger.info('Remove duplicate segments')
             for key in db:
                 db[key] = remove_duplicate_segments(db[key])
 
@@ -646,7 +648,7 @@ def main():
 
         # TODO remove videos that contain a full-video label?
 
-        print('Saved', len(db), 'videos')
+        logger.info(f'Saved {len(db)} videos')
 
         with open(processed_db_path, 'w') as fp:
             json.dump(db, fp)
@@ -660,7 +662,7 @@ def main():
     # 'userID', 'timeSubmitted', 'views', 'category', 'actionType', 'service', 'videoDuration',
     # 'hidden', 'reputation', 'shadowHidden', 'hashedVideoID', 'userAgent', 'description'
     if preprocess_args.do_transcribe:
-        print('Collecting videos')
+        logger.info('Collecting videos')
         parsed_database = read_db()
 
         # Remove transcripts already processed
@@ -678,7 +680,7 @@ def main():
             get_words(video_id)
             return video_id
 
-        print('Setting up ThreadPoolExecutor')
+        logger.info('Setting up ThreadPoolExecutor')
         with concurrent.futures.ThreadPoolExecutor(max_workers=preprocess_args.num_jobs) as pool, \
                 tqdm(total=len(video_ids)) as progress:
 
@@ -698,21 +700,21 @@ def main():
                     progress.update()
 
             except KeyboardInterrupt:
-                print('Gracefully shutting down: Cancelling unscheduled tasks')
+                logger.info('Gracefully shutting down: Cancelling unscheduled tasks')
 
                 # only futures that are not done will prevent exiting
                 for future in to_process:
                     future.cancel()
 
-                print('Waiting for in-progress tasks to complete')
+                logger.info('Waiting for in-progress tasks to complete')
                 concurrent.futures.wait(to_process, timeout=None)
-                print('Cancellation successful')
+                logger.info('Cancellation successful')
 
     final_path = os.path.join(
         dataset_args.data_dir, dataset_args.processed_file)
 
     if preprocess_args.do_create:
-        print('Create final data')
+        logger.info('Create final data')
 
         final_data = {}
 
@@ -786,7 +788,7 @@ def main():
         dataset_args.data_dir, dataset_args.negative_file)
 
     if preprocess_args.do_generate:
-        print('Generating')
+        logger.info('Generating')
         # max_videos=preprocess_args.max_videos,
         # max_segments=preprocess_args.max_segments,
         # , max_videos, max_segments
@@ -868,8 +870,8 @@ def main():
                     print(json.dumps(d), file=negative)
 
     if preprocess_args.do_split:
-        print('Splitting')
-        print('Read files')
+        logger.info('Splitting')
+        logger.info('Read files')
 
         with open(positive_file, encoding='utf-8') as positive:
             sponsors = positive.readlines()
@@ -877,11 +879,11 @@ def main():
         with open(negative_file, encoding='utf-8') as negative:
            non_sponsors = negative.readlines()
 
-        print('Shuffle')
+        logger.info('Shuffle')
         random.shuffle(sponsors)
         random.shuffle(non_sponsors)
 
-        print('Calculate ratios')
+        logger.info('Calculate ratios')
         # Ensure correct ratio of positive to negative segments
         percentage_negative = 1 - preprocess_args.percentage_positive
 
@@ -901,12 +903,12 @@ def main():
             excess = non_sponsors[z:]
             non_sponsors = non_sponsors[:z]
 
-        print('Join')
+        logger.info('Join')
         all_labelled_segments = sponsors + non_sponsors
 
         random.shuffle(all_labelled_segments)
 
-        print('Split')
+        logger.info('Split')
         ratios = [preprocess_args.train_split,
                   preprocess_args.test_split,
                   preprocess_args.valid_split]
@@ -927,9 +929,9 @@ def main():
                 with open(outfile, 'w', encoding='utf-8') as fp:
                     fp.writelines(items)
             else:
-                print('Skipping', name)
+                logger.info(f'Skipping {name}')
 
-        print('Write')
+        logger.info('Write')
         # Save excess items
         excess_path = os.path.join(
             dataset_args.data_dir, dataset_args.excess_file)
@@ -937,10 +939,9 @@ def main():
             with open(excess_path, 'w', encoding='utf-8') as fp:
                 fp.writelines(excess)
         else:
-            print('Skipping', dataset_args.excess_file)
+            logger.info(f'Skipping {dataset_args.excess_file}')
 
-        print('Finished splitting:', len(sponsors),
-              'sponsors,', len(non_sponsors), 'non sponsors')
+        logger.info(f'Finished splitting: {len(sponsors)} sponsors, {len(non_sponsors)} non sponsors')
 
 
 def split(arr, ratios):
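
For context, a minimal sketch of how the new module-level logger could be surfaced when the script is run. This is not part of the commit: it assumes "import logging" is already present in src/preprocess.py (the diff does not touch the import block), and the logging.basicConfig call, format string, and __main__ guard below are illustrative assumptions. Without some handler configuration, Python's last-resort handler only emits WARNING and above, so the converted logger.info messages would otherwise stay silent.

import logging

logger = logging.getLogger(__name__)  # as added near the top of src/preprocess.py


def main():
    # stand-in for the converted calls, e.g.:
    logger.info('Updating database')


if __name__ == '__main__':
    # Hypothetical setup (not shown in this commit): route INFO-level records
    # to the console so the messages that used to be print()s remain visible.
    logging.basicConfig(
        format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
        level=logging.INFO,
    )
    main()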