Joshua Lochner committed · Commit 643d00a · Parent(s): a0ca50e

Remove PreprocessingDatasetArguments class

Files changed:
- src/preprocess.py +9 -24
- src/shared.py +2 -1
src/preprocess.py CHANGED

@@ -420,6 +420,14 @@ class PreprocessArguments:
     do_split: bool = field(
         default=False, metadata={'help': 'Generate training, testing and validation data.'}
     )
+
+    positive_file: Optional[str] = field(
+        default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
+    )
+    negative_file: Optional[str] = field(
+        default='normal_segments.json', metadata={'help': 'File to output normal segments to (a jsonlines file).'}
+    )
+
     percentage_positive: float = field(
         default=0.5, metadata={'help': 'Ratio of positive (sponsor) segments to include in final output'})

@@ -488,29 +496,6 @@ def download_file(url, filename):
     return total_bytes == os.path.getsize(filename)


-@dataclass
-class PreprocessingDatasetArguments(DatasetArguments):
-    # excess_file: Optional[str] = field(
-    #     default='excess.json',
-    #     metadata={
-    #         'help': 'The excess segments left after the split'
-    #     },
-    # )
-
-    positive_file: Optional[str] = field(
-        default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
-    )
-    negative_file: Optional[str] = field(
-        default='normal_segments.json', metadata={'help': 'File to output normal segments to (a jsonlines file).'}
-    )
-
-    def __post_init__(self):
-        # TODO check if train/validation datasets exist
-        if self.train_file is None and self.validation_file is None:
-            raise ValueError(
-                'Need either a dataset name or a training/validation file.')
-
-
 def main():
     # Responsible for getting transcrips using youtube_transcript_api,
     # then labelling it according to SponsorBlock's API

@@ -519,7 +504,7 @@ def main():
     # Generate final.json from sponsorTimes.csv
     hf_parser = HfArgumentParser((
         PreprocessArguments,
-        PreprocessingDatasetArguments,
+        DatasetArguments,
         segment.SegmentationArguments,
         model_module.ModelArguments,
         GeneralArguments
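For reference, the preprocess.py change folds the two output-file fields into PreprocessArguments. Below is a minimal sketch of the affected portion of the dataclass after this commit, reproducing only the fields visible in the hunks above; the rest of the class is omitted.

from dataclasses import dataclass, field
from typing import Optional


@dataclass
class PreprocessArguments:
    # Sketch: only the fields shown in the diff above; the real class
    # defines additional fields that are not reproduced here.
    do_split: bool = field(
        default=False, metadata={'help': 'Generate training, testing and validation data.'}
    )

    positive_file: Optional[str] = field(
        default='sponsor_segments.json', metadata={'help': 'File to output sponsored segments to (a jsonlines file).'}
    )
    negative_file: Optional[str] = field(
        default='normal_segments.json', metadata={'help': 'File to output normal segments to (a jsonlines file).'}
    )

    percentage_positive: float = field(
        default=0.5, metadata={'help': 'Ratio of positive (sponsor) segments to include in final output'})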
src/shared.py CHANGED

@@ -137,7 +137,8 @@ class DatasetArguments:
     def __post_init__(self):
         if self.train_file is None or self.validation_file is None:
             raise ValueError(
-
+                'Need either a dataset name or a training/validation file.')
+
         else:
             train_extension = self.train_file.split(".")[-1]
             assert train_extension in [
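With PreprocessingDatasetArguments removed, the parser in main() now takes DatasetArguments from src/shared.py directly. A hedged sketch of how the parsed groups would typically be unpacked, assuming the standard parse_args_into_dataclasses() call and the project-local import paths shown below (neither appears in this diff):

# Sketch only: the unpacking call and the import paths are assumptions;
# the diff above shows just the tuple passed to HfArgumentParser.
from transformers import HfArgumentParser

from shared import DatasetArguments, GeneralArguments  # assumed module layout
from preprocess import PreprocessArguments             # assumed module layout
import segment                                         # assumed module
import model as model_module                           # assumed alias

hf_parser = HfArgumentParser((
    PreprocessArguments,
    DatasetArguments,
    segment.SegmentationArguments,
    model_module.ModelArguments,
    GeneralArguments
))

# parse_args_into_dataclasses() returns one instance per dataclass, in order.
preprocess_args, dataset_args, segmentation_args, model_args, general_args = \
    hf_parser.parse_args_into_dataclasses()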