Joshua Lochner committed
Commit · 508e8b2
Parent: 2115d78

Fix classifier preprocessing

Files changed: src/preprocess.py (+29 -39)
src/preprocess.py
CHANGED
@@ -9,7 +9,7 @@ import segment
 from tqdm import tqdm
 from dataclasses import dataclass, field
 from transformers import HfArgumentParser
-from shared import
+from shared import extract_sponsor_matches_from_text, ACTION_OPTIONS, CATEGORIES, CATGEGORY_OPTIONS, START_SEGMENT_TEMPLATE, END_SEGMENT_TEMPLATE, GeneralArguments, CustomTokens
 import csv
 import re
 import random
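The widened `shared` import is what the rest of this commit builds on: the classifier-writing loop further down calls `extract_sponsor_matches_from_text(...)`, indexes each result as `match['text']` and `match['category']`, and looks the category up in `CATEGORIES`. A hypothetical stand-in showing only that assumed shape (the real implementation lives in shared.py and is not part of this diff):

def extract_sponsor_matches_from_text(text):
    # Stand-in only: the diff tells us this returns a (possibly empty) list of
    # dicts carrying 'text' and 'category' keys, consumed by the loop below.
    matches = []
    # ... real matching logic is in shared.py ...
    return matches  # e.g. [{'text': 'Thanks to X for sponsoring', 'category': 'sponsor'}]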
@@ -418,9 +418,9 @@ class PreprocessArguments:
     num_jobs: int = field(
         default=4, metadata={'help': 'Number of transcripts to download in parallel'})

-    overwrite: bool = field(
-        default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
-    )
+    # overwrite: bool = field(
+    #     default=False, metadata={'help': 'Overwrite training, testing and validation data, if present.'}
+    # )

     do_generate: bool = field(
         default=False, metadata={'help': 'Generate labelled data.'}
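Commenting out the `overwrite` field also removes the `--overwrite` option from the command line, because `HfArgumentParser` generates one flag per dataclass field. A minimal, self-contained sketch of that mechanism (the toy PreprocessArgsSketch class below is not from the repo):

from dataclasses import dataclass, field
from transformers import HfArgumentParser

@dataclass
class PreprocessArgsSketch:
    # One CLI flag is generated per field, e.g. --num_jobs below.
    num_jobs: int = field(
        default=4, metadata={'help': 'Number of transcripts to download in parallel'})
    # With the overwrite field commented out, passing --overwrite would now fail
    # with "unrecognized arguments".

parser = HfArgumentParser(PreprocessArgsSketch)
(args,) = parser.parse_args_into_dataclasses(args=['--num_jobs', '8'])
print(args.num_jobs)  # -> 8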
@@ -538,11 +538,11 @@ def main():
     # TODO process all valid possible items and then do filtering only later
     @lru_cache(maxsize=1)
     def read_db():
-        if not preprocess_args.overwrite and os.path.exists(processed_db_path):
-            logger.info(
-                'Using cached processed database (use `--overwrite` to avoid this behaviour).')
-            with open(processed_db_path) as fp:
-                return json.load(fp)
+        # if not preprocess_args.overwrite and os.path.exists(processed_db_path):
+        #     logger.info(
+        #         'Using cached processed database (use `--overwrite` to avoid this behaviour).')
+        #     with open(processed_db_path) as fp:
+        #         return json.load(fp)
         logger.info('Processing raw database')
         db = {}

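With the on-disk cache commented out, `read_db` always reprocesses the raw database; only `@lru_cache(maxsize=1)` still prevents reprocessing within a single run, since later calls to the zero-argument function return the memoised result. A standalone illustration of that behaviour (not code from the repo):

from functools import lru_cache

calls = {'count': 0}

@lru_cache(maxsize=1)
def read_db_sketch():
    # Body executes once per process; subsequent calls reuse the cached value.
    calls['count'] += 1
    return {'videos': []}  # stand-in for the processed database

read_db_sketch()
read_db_sketch()
print(calls['count'])  # -> 1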
@@ -916,11 +916,8 @@ def main():
     # Output training, testing and validation data
     for name, items in splits.items():
         outfile = os.path.join(dataset_args.data_dir, name)
-                fp.writelines(items)
-        else:
-            logger.info(f'Skipping {name}')
+        with open(outfile, 'w', encoding='utf-8') as fp:
+            fp.writelines(items)

     classifier_splits = {
         dataset_args.c_train_file: train_data,
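The exists-check/`Skipping` branch is dropped along with the `overwrite` flag, so the split files are now always rewritten. `fp.writelines(items)` adds no separators of its own, which assumes each item already ends with a newline; a quick standalone reminder of that semantics:

# writelines() concatenates the strings as-is and does not append newlines.
items = ['{"text": "a"}\n', '{"text": "b"}\n']  # hypothetical pre-serialised rows
with open('train_sketch.jsonl', 'w', encoding='utf-8') as fp:  # placeholder file name
    fp.writelines(items)
# train_sketch.jsonl now holds one JSON object per line.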
@@ -933,31 +930,24 @@ def main():
     # Output training, testing and validation data
     for name, items in classifier_splits.items():
         outfile = os.path.join(dataset_args.data_dir, name)
-            for labelled_item in labelled_items:
-                print(json.dumps(labelled_item), file=fp)
-        else:
-            logger.info(f'Skipping {name}')
+        with open(outfile, 'w', encoding='utf-8') as fp:
+            for item in items:
+                parsed_item = json.loads(item)  # TODO add uuid
+
+                matches = extract_sponsor_matches_from_text(parsed_item['extracted'])
+
+                if matches:
+                    for match in matches:
+                        print(json.dumps({
+                            'text': match['text'],
+                            'label': CATEGORIES.index(match['category'])
+                        }), file=fp)
+                else:
+                    print(json.dumps({
+                        'text': parsed_item['text'],
+                        'label': none_category
+                    }), file=fp)
+

     logger.info('Write')
     # Save excess items
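This hunk is the actual classifier fix: instead of dumping pre-built `labelled_items`, each item is re-parsed and written as one {'text': ..., 'label': ...} JSON object per line, where the label is the index of the matched category in `CATEGORIES`, or `none_category` when `extract_sponsor_matches_from_text` finds nothing. The output is therefore plain JSON Lines; a hedged sketch of reading it back (the path 'data/c_train.json' is a placeholder for dataset_args.data_dir joined with dataset_args.c_train_file):

import json

with open('data/c_train.json', encoding='utf-8') as fp:  # placeholder path
    rows = [json.loads(line) for line in fp]

print(rows[0]['text'], rows[0]['label'])

# Equivalently, the file can be loaded with the Hugging Face datasets library:
# from datasets import load_dataset
# dataset = load_dataset('json', data_files={'train': 'data/c_train.json'})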