Joshua Lochner committed
Commit 508e8b2 · 1 Parent(s): 2115d78

Fix classifier preprocessing

Files changed (1):
  1. src/preprocess.py +29 -39
src/preprocess.py CHANGED

@@ -9,7 +9,7 @@ import segment
 from tqdm import tqdm
 from dataclasses import dataclass, field
 from transformers import HfArgumentParser
-from shared import extract_sponsor_matches, ACTION_OPTIONS, CATEGORIES, CATGEGORY_OPTIONS, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, CustomTokens
+from shared import extract_sponsor_matches_from_text, ACTION_OPTIONS, CATEGORIES, CATGEGORY_OPTIONS, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, CustomTokens
 import csv
 import re
 import random
@@ -418,9 +418,9 @@ class PreprocessArguments:
     num_jobs: int = field(
         default=4, metadata={'help': 'Number of transcripts to download in parallel'})

-    overwrite: bool = field(
-        default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
-    )
+    # overwrite: bool = field(
+    #     default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
+    # )

     do_generate: bool = field(
         default=False, metadata={'help': 'Generate labelled data.'}
@@ -538,11 +538,11 @@ def main():
     # TODO process all valid possible items and then do filtering only later
     @lru_cache(maxsize=1)
     def read_db():
-        if not preprocess_args.overwrite and os.path.exists(processed_db_path):
-            logger.info(
-                'Using cached processed database (use `--overwrite` to avoid this behaviour).')
-            with open(processed_db_path) as fp:
-                return json.load(fp)
+        # if not preprocess_args.overwrite and os.path.exists(processed_db_path):
+        #     logger.info(
+        #         'Using cached processed database (use `--overwrite` to avoid this behaviour).')
+        #     with open(processed_db_path) as fp:
+        #         return json.load(fp)
         logger.info('Processing raw database')
         db = {}

@@ -916,11 +916,8 @@ def main():
     # Output training, testing and validation data
     for name, items in splits.items():
         outfile = os.path.join(dataset_args.data_dir, name)
-        if not os.path.exists(outfile) or preprocess_args.overwrite:
-            with open(outfile, 'w', encoding='utf-8') as fp:
-                fp.writelines(items)
-        else:
-            logger.info(f'Skipping {name}')
+        with open(outfile, 'w', encoding='utf-8') as fp:
+            fp.writelines(items)

     classifier_splits = {
         dataset_args.c_train_file: train_data,
@@ -933,31 +930,24 @@ def main():
     # Output training, testing and validation data
     for name, items in classifier_splits.items():
         outfile = os.path.join(dataset_args.data_dir, name)
-        if not os.path.exists(outfile) or preprocess_args.overwrite:
-            with open(outfile, 'w', encoding='utf-8') as fp:
-                for i in items:
-                    x = json.loads(i)  # TODO add uuid
-                    labelled_items = []
-
-                    matches = extract_sponsor_matches(x['extracted'])
-
-                    if x['extracted'] == CustomTokens.NO_SEGMENT.value:
-                        labelled_items.append({
-                            'text': x['text'],
-                            'label': none_category
-                        })
-                    else:
-                        for match in matches:
-                            labelled_items.append({
-                                'text': match['text'],
-                                'label': CATEGORIES.index(match['category'])
-                            })
-
-                    for labelled_item in labelled_items:
-                        print(json.dumps(labelled_item), file=fp)
-
-        else:
-            logger.info(f'Skipping {name}')
+        with open(outfile, 'w', encoding='utf-8') as fp:
+            for item in items:
+                parsed_item = json.loads(item)  # TODO add uuid
+
+                matches = extract_sponsor_matches_from_text(parsed_item['extracted'])
+
+                if matches:
+                    for match in matches:
+                        print(json.dumps({
+                            'text': match['text'],
+                            'label': CATEGORIES.index(match['category'])
+                        }), file=fp)
+                else:
+                    print(json.dumps({
+                        'text': parsed_item['text'],
+                        'label': none_category
+                    }), file=fp)
+

     logger.info('Write')
     # Save excess items
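
Note on the rewritten classifier loop: each classifier split file is now written as JSON Lines, one object per line with a 'text' string and an integer 'label' (an index into CATEGORIES when extract_sponsor_matches_from_text returns matches, or none_category when it returns none), and the no-segment case is detected from an empty match list rather than by comparing against CustomTokens.NO_SEGMENT.value. Below is a minimal sketch of a reader for files in that format, assuming only what the diff shows; the function name and the example path are illustrative and not part of the repository.

import json

def read_classifier_split(path):
    # Yield (text, label) pairs from a classifier split file written by
    # preprocess.py: one JSON object per line with 'text' and 'label' keys.
    with open(path, encoding='utf-8') as fp:
        for line in fp:
            line = line.strip()
            if not line:
                continue
            record = json.loads(line)
            yield record['text'], record['label']

# Illustrative usage; the real file names come from the classifier_splits keys
# (dataset_args.c_train_file, etc.), not from a hard-coded path:
# for text, label in read_classifier_split('c_train.json'):
#     ...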