Spaces:
Running
Running
Joshua Lochner
committed on
Commit
·
fb87012
1
Parent(s):
02e576a
Improve caching and downloading of classifier for predictions
Browse files
- src/evaluate.py +4 -3
- src/predict.py +20 -8
src/evaluate.py
CHANGED
@@ -205,7 +205,7 @@ def main():
|
|
205 |
|
206 |
evaluation_args, dataset_args, segmentation_args, classifier_args, _ = hf_parser.parse_args_into_dataclasses()
|
207 |
|
208 |
-
model, tokenizer = get_model_tokenizer(evaluation_args.model_path)
|
209 |
|
210 |
# # TODO find better way of evaluating videos not trained on
|
211 |
# dataset = load_dataset('json', data_files=os.path.join(
|
@@ -313,8 +313,9 @@ def main():
|
|
313 |
[w['text'] for w in missed_segment['words']]), '"', sep='')
|
314 |
print('\t\tCategory:',
|
315 |
missed_segment.get('category'))
|
316 |
-
|
317 |
-
|
|
|
318 |
|
319 |
segments_to_submit.append({
|
320 |
'segment': [missed_segment['start'], missed_segment['end']],
|
|
|
205 |
|
206 |
evaluation_args, dataset_args, segmentation_args, classifier_args, _ = hf_parser.parse_args_into_dataclasses()
|
207 |
|
208 |
+
model, tokenizer = get_model_tokenizer(evaluation_args.model_path, evaluation_args.cache_dir)
|
209 |
|
210 |
# # TODO find better way of evaluating videos not trained on
|
211 |
# dataset = load_dataset('json', data_files=os.path.join(
|
|
|
313 |
[w['text'] for w in missed_segment['words']]), '"', sep='')
|
314 |
print('\t\tCategory:',
|
315 |
missed_segment.get('category'))
|
316 |
+
if 'probability' in missed_segment:
|
317 |
+
print('\t\tProbability:',
|
318 |
+
missed_segment['probability'])
|
319 |
|
320 |
segments_to_submit.append({
|
321 |
'segment': [missed_segment['start'], missed_segment['end']],
|
src/predict.py
CHANGED
@@ -11,8 +11,8 @@ from segment import (
|
|
11 |
SegmentationArguments
|
12 |
)
|
13 |
import preprocess
|
14 |
-
from errors import TranscriptError, ModelLoadError
|
15 |
-
from model import get_classifier_vectorizer, get_model_tokenizer
|
16 |
from transformers import HfArgumentParser
|
17 |
from transformers.trainer_utils import get_last_checkpoint
|
18 |
from dataclasses import dataclass, field
|
@@ -29,6 +29,7 @@ class TrainingOutputArguments:
|
|
29 |
'help': 'Path to pretrained model used for prediction'
|
30 |
}
|
31 |
)
|
|
|
32 |
|
33 |
output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
|
34 |
'output_dir']
|
@@ -43,7 +44,8 @@ class TrainingOutputArguments:
|
|
43 |
self.model_path = last_checkpoint
|
44 |
return
|
45 |
|
46 |
-
raise ModelLoadError(
|
|
|
47 |
|
48 |
|
49 |
@dataclass
|
@@ -65,6 +67,13 @@ MERGE_TIME_WITHIN = 8 # Merge predictions if they are within x seconds
|
|
65 |
|
66 |
@dataclass(frozen=True, eq=True)
|
67 |
class ClassifierArguments:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
68 |
classifier_dir: Optional[str] = field(
|
69 |
default='classifiers',
|
70 |
metadata={
|
@@ -90,7 +99,6 @@ class ClassifierArguments:
|
|
90 |
default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
|
91 |
|
92 |
|
93 |
-
# classifier, vectorizer,
|
94 |
def filter_and_add_probabilities(predictions, classifier_args):
|
95 |
"""Use classifier to filter predictions"""
|
96 |
if not predictions:
|
@@ -160,8 +168,11 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
|
|
160 |
|
161 |
# TODO add back
|
162 |
if classifier_args is not None:
|
163 |
-
|
164 |
-
predictions
|
|
|
|
|
|
|
165 |
|
166 |
return predictions
|
167 |
|
@@ -290,7 +301,7 @@ def main():
|
|
290 |
print('No video ID supplied. Use `--video_id`.')
|
291 |
return
|
292 |
|
293 |
-
model, tokenizer = get_model_tokenizer(predict_args.model_path)
|
294 |
|
295 |
predict_args.video_id = predict_args.video_id.strip()
|
296 |
predictions = predict(predict_args.video_id, model, tokenizer,
|
@@ -308,8 +319,9 @@ def main():
|
|
308 |
' '.join([w['text'] for w in prediction['words']]), '"', sep='')
|
309 |
print('Time:', seconds_to_time(
|
310 |
prediction['start']), '\u2192', seconds_to_time(prediction['end']))
|
311 |
-
print('Probability:', prediction.get('probability'))
|
312 |
print('Category:', prediction.get('category'))
|
|
|
|
|
313 |
print()
|
314 |
|
315 |
|
|
|
11 |
SegmentationArguments
|
12 |
)
|
13 |
import preprocess
|
14 |
+
from errors import TranscriptError, ModelLoadError, ClassifierLoadError
|
15 |
+
from model import ModelArguments, get_classifier_vectorizer, get_model_tokenizer
|
16 |
from transformers import HfArgumentParser
|
17 |
from transformers.trainer_utils import get_last_checkpoint
|
18 |
from dataclasses import dataclass, field
|
|
|
29 |
'help': 'Path to pretrained model used for prediction'
|
30 |
}
|
31 |
)
|
32 |
+
cache_dir: Optional[str] = ModelArguments.__dataclass_fields__['cache_dir']
|
33 |
|
34 |
output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
|
35 |
'output_dir']
|
|
|
44 |
self.model_path = last_checkpoint
|
45 |
return
|
46 |
|
47 |
+
raise ModelLoadError(
|
48 |
+
'Unable to find model, explicitly set `--model_path`')
|
49 |
|
50 |
|
51 |
@dataclass
|
|
|
67 |
|
68 |
@dataclass(frozen=True, eq=True)
|
69 |
class ClassifierArguments:
|
70 |
+
classifier_model: Optional[str] = field(
|
71 |
+
default='Xenova/sponsorblock-classifier',
|
72 |
+
metadata={
|
73 |
+
'help': 'Use a pretrained classifier'
|
74 |
+
}
|
75 |
+
)
|
76 |
+
|
77 |
classifier_dir: Optional[str] = field(
|
78 |
default='classifiers',
|
79 |
metadata={
|
|
|
99 |
default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
|
100 |
|
101 |
|
|
|
102 |
def filter_and_add_probabilities(predictions, classifier_args):
|
103 |
"""Use classifier to filter predictions"""
|
104 |
if not predictions:
|
|
|
168 |
|
169 |
# TODO add back
|
170 |
if classifier_args is not None:
|
171 |
+
try:
|
172 |
+
predictions = filter_and_add_probabilities(
|
173 |
+
predictions, classifier_args)
|
174 |
+
except ClassifierLoadError:
|
175 |
+
print('Unable to load classifer')
|
176 |
|
177 |
return predictions
|
178 |
|
|
|
301 |
print('No video ID supplied. Use `--video_id`.')
|
302 |
return
|
303 |
|
304 |
+
model, tokenizer = get_model_tokenizer(predict_args.model_path, predict_args.cache_dir)
|
305 |
|
306 |
predict_args.video_id = predict_args.video_id.strip()
|
307 |
predictions = predict(predict_args.video_id, model, tokenizer,
|
|
|
319 |
' '.join([w['text'] for w in prediction['words']]), '"', sep='')
|
320 |
print('Time:', seconds_to_time(
|
321 |
prediction['start']), '\u2192', seconds_to_time(prediction['end']))
|
|
|
322 |
print('Category:', prediction.get('category'))
|
323 |
+
if 'probability' in prediction:
|
324 |
+
print('Probability:', prediction['probability'])
|
325 |
print()
|
326 |
|
327 |
|