Joshua Lochner committed on
Commit
643d00a
·
1 Parent(s): a0ca50e

Remove PreprocessingDatasetArguments class

Browse files
Files changed (2) hide show
  1. src/preprocess.py +9 -24
  2. src/shared.py +2 -1
src/preprocess.py CHANGED
@@ -420,6 +420,14 @@ class PreprocessArguments:
420
  do_split: bool = field(
421
  default=False, metadata={'help': 'Generate training, testing and validation data.'}
422
  )
 
 
 
 
 
 
 
 
423
  percentage_positive: float = field(
424
  default=0.5, metadata={'help': 'Ratio of positive (sponsor) segments to include in final output'})
425
 
@@ -488,29 +496,6 @@ def download_file(url, filename):
488
  return total_bytes == os.path.getsize(filename)
489
 
490
 
491
- @dataclass
492
- class PreprocessingDatasetArguments(DatasetArguments):
493
- # excess_file: Optional[str] = field(
494
- # default='excess.json',
495
- # metadata={
496
- # 'help': 'The excess segments left after the split'
497
- # },
498
- # )
499
-
500
- positive_file: Optional[str] = field(
501
- default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
502
- )
503
- negative_file: Optional[str] = field(
504
- default='normal_segments.json', metadata={'help': 'File to output normal segments to (a jsonlines file).'}
505
- )
506
-
507
- def __post_init__(self):
508
- # TODO check if train/validation datasets exist
509
- if self.train_file is None and self.validation_file is None:
510
- raise ValueError(
511
- 'Need either a dataset name or a training/validation file.')
512
-
513
-
514
  def main():
515
  # Responsible for getting transcrips using youtube_transcript_api,
516
  # then labelling it according to SponsorBlock's API
@@ -519,7 +504,7 @@ def main():
519
  # Generate final.json from sponsorTimes.csv
520
  hf_parser = HfArgumentParser((
521
  PreprocessArguments,
522
- PreprocessingDatasetArguments,
523
  segment.SegmentationArguments,
524
  model_module.ModelArguments,
525
  GeneralArguments
 
420
  do_split: bool = field(
421
  default=False, metadata={'help': 'Generate training, testing and validation data.'}
422
  )
423
+
424
+ positive_file: Optional[str] = field(
425
+ default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
426
+ )
427
+ negative_file: Optional[str] = field(
428
+ default='normal_segments.json', metadata={'help': 'File to output normal segments to (a jsonlines file).'}
429
+ )
430
+
431
  percentage_positive: float = field(
432
  default=0.5, metadata={'help': 'Ratio of positive (sponsor) segments to include in final output'})
433
 
 
496
  return total_bytes == os.path.getsize(filename)
497
 
498
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
499
  def main():
500
  # Responsible for getting transcrips using youtube_transcript_api,
501
  # then labelling it according to SponsorBlock's API
 
504
  # Generate final.json from sponsorTimes.csv
505
  hf_parser = HfArgumentParser((
506
  PreprocessArguments,
507
+ DatasetArguments,
508
  segment.SegmentationArguments,
509
  model_module.ModelArguments,
510
  GeneralArguments
src/shared.py CHANGED
@@ -137,7 +137,8 @@ class DatasetArguments:
137
  def __post_init__(self):
138
  if self.train_file is None or self.validation_file is None:
139
  raise ValueError(
140
- "Need either a GLUE task, a training/validation file or a dataset name.")
 
141
  else:
142
  train_extension = self.train_file.split(".")[-1]
143
  assert train_extension in [
 
137
  def __post_init__(self):
138
  if self.train_file is None or self.validation_file is None:
139
  raise ValueError(
140
+ 'Need either a dataset name or a training/validation file.')
141
+
142
  else:
143
  train_extension = self.train_file.split(".")[-1]
144
  assert train_extension in [