Joshua Lochner committed
Commit fb87012 · 1 Parent(s): 02e576a

Improve caching and downloading of classifier for predictions

Files changed (2)
  1. src/evaluate.py +4 -3
  2. src/predict.py +20 -8
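
Note: the commit threads a new cache_dir argument through get_model_tokenizer in both scripts. get_model_tokenizer itself lives in src/model.py and is not part of this diff; as a rough sketch of what passing a cache directory through to transformers usually looks like (the concrete Auto* classes below are assumptions, not the project's actual code):

    # Sketch only - src/model.py is not shown in this commit, and the Auto*
    # classes here are assumptions. cache_dir simply tells transformers where
    # to store and reuse downloaded weights and tokenizer files.
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    def get_model_tokenizer(model_path, cache_dir=None):
        model = AutoModelForSeq2SeqLM.from_pretrained(model_path, cache_dir=cache_dir)
        tokenizer = AutoTokenizer.from_pretrained(model_path, cache_dir=cache_dir)
        return model, tokenizer
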
src/evaluate.py CHANGED
@@ -205,7 +205,7 @@ def main():
 
     evaluation_args, dataset_args, segmentation_args, classifier_args, _ = hf_parser.parse_args_into_dataclasses()
 
-    model, tokenizer = get_model_tokenizer(evaluation_args.model_path)
+    model, tokenizer = get_model_tokenizer(evaluation_args.model_path, evaluation_args.cache_dir)
 
     # # TODO find better way of evaluating videos not trained on
     # dataset = load_dataset('json', data_files=os.path.join(
@@ -313,8 +313,9 @@ def main():
                     [w['text'] for w in missed_segment['words']]), '"', sep='')
                 print('\t\tCategory:',
                       missed_segment.get('category'))
-                print('\t\tProbability:',
-                      missed_segment.get('probability'))
+                if 'probability' in missed_segment:
+                    print('\t\tProbability:',
+                          missed_segment['probability'])
 
                 segments_to_submit.append({
                     'segment': [missed_segment['start'], missed_segment['end']],
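
With evaluate.py now forwarding evaluation_args.cache_dir, the cache location can be chosen from the command line (HfArgumentParser exposes dataclass fields as flags). A hypothetical invocation, assuming the evaluation arguments dataclass exposes --cache_dir and with placeholder paths:

    # Hypothetical usage; model path and cache directory are placeholders.
    python src/evaluate.py --model_path models/checkpoint --cache_dir ~/.cache/huggingface
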
src/predict.py CHANGED
@@ -11,8 +11,8 @@ from segment import (
     SegmentationArguments
 )
 import preprocess
-from errors import TranscriptError, ModelLoadError
-from model import get_classifier_vectorizer, get_model_tokenizer
+from errors import TranscriptError, ModelLoadError, ClassifierLoadError
+from model import ModelArguments, get_classifier_vectorizer, get_model_tokenizer
 from transformers import HfArgumentParser
 from transformers.trainer_utils import get_last_checkpoint
 from dataclasses import dataclass, field
@@ -29,6 +29,7 @@ class TrainingOutputArguments:
             'help': 'Path to pretrained model used for prediction'
         }
     )
+    cache_dir: Optional[str] = ModelArguments.__dataclass_fields__['cache_dir']
 
     output_dir: Optional[str] = OutputArguments.__dataclass_fields__[
         'output_dir']
@@ -43,7 +44,8 @@ class TrainingOutputArguments:
             self.model_path = last_checkpoint
             return
 
-        raise ModelLoadError('Unable to find model, explicitly set `--model_path`')
+        raise ModelLoadError(
+            'Unable to find model, explicitly set `--model_path`')
 
 
 @dataclass
@@ -65,6 +67,13 @@ MERGE_TIME_WITHIN = 8 # Merge predictions if they are within x seconds
 
 @dataclass(frozen=True, eq=True)
 class ClassifierArguments:
+    classifier_model: Optional[str] = field(
+        default='Xenova/sponsorblock-classifier',
+        metadata={
+            'help': 'Use a pretrained classifier'
+        }
+    )
+
     classifier_dir: Optional[str] = field(
         default='classifiers',
         metadata={
@@ -90,7 +99,6 @@ class ClassifierArguments:
         default=0.5, metadata={'help': 'Remove all predictions whose classification probability is below this threshold.'})
 
 
-# classifier, vectorizer,
 def filter_and_add_probabilities(predictions, classifier_args):
     """Use classifier to filter predictions"""
     if not predictions:
@@ -160,8 +168,11 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
 
     # TODO add back
     if classifier_args is not None:
-        predictions = filter_and_add_probabilities(
-            predictions, classifier_args)
+        try:
+            predictions = filter_and_add_probabilities(
+                predictions, classifier_args)
+        except ClassifierLoadError:
+            print('Unable to load classifer')
 
     return predictions
 
@@ -290,7 +301,7 @@ def main():
         print('No video ID supplied. Use `--video_id`.')
         return
 
-    model, tokenizer = get_model_tokenizer(predict_args.model_path)
+    model, tokenizer = get_model_tokenizer(predict_args.model_path, predict_args.cache_dir)
 
     predict_args.video_id = predict_args.video_id.strip()
     predictions = predict(predict_args.video_id, model, tokenizer,
@@ -308,8 +319,9 @@ def main():
         ' '.join([w['text'] for w in prediction['words']]), '"', sep='')
     print('Time:', seconds_to_time(
         prediction['start']), '\u2192', seconds_to_time(prediction['end']))
-    print('Probability:', prediction.get('probability'))
     print('Category:', prediction.get('category'))
+    if 'probability' in prediction:
+        print('Probability:', prediction['probability'])
     print()
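
The new classifier_model field defaults to the 'Xenova/sponsorblock-classifier' Hub repository, and predict() now catches ClassifierLoadError around filter_and_add_probabilities, so a classifier that cannot be downloaded or loaded only disables probability output instead of aborting the run. get_classifier_vectorizer itself is defined in src/model.py and is not shown in this diff; one plausible shape for downloading and caching the classifier from the Hub is sketched below (file names and error handling are assumptions):

    # Sketch only - not the code from src/model.py. File names are hypothetical;
    # hf_hub_download caches the downloaded files locally and reuses them on
    # subsequent runs.
    import pickle
    from huggingface_hub import hf_hub_download
    from errors import ClassifierLoadError

    def get_classifier_vectorizer(classifier_args):
        try:
            classifier_path = hf_hub_download(repo_id=classifier_args.classifier_model,
                                              filename='classifier.pickle')
            vectorizer_path = hf_hub_download(repo_id=classifier_args.classifier_model,
                                              filename='vectorizer.pickle')
        except Exception as e:
            # Surface download/availability problems as the error predict() now catches
            raise ClassifierLoadError('Unable to download classifier') from e

        with open(classifier_path, 'rb') as f:
            classifier = pickle.load(f)
        with open(vectorizer_path, 'rb') as f:
            vectorizer = pickle.load(f)
        return classifier, vectorizer
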