Commit dccb47b · Joshua Lochner committed · 1 parent: 67d0193

Separate missing and incorrect detection logic

Files changed (1):
  1. src/evaluate.py +140 -78
src/evaluate.py CHANGED
@@ -1,8 +1,8 @@
 
 from model import get_model_tokenizer_classifier, InferenceArguments
-from utils import jaccard
+from utils import jaccard, safe_print
 from transformers import HfArgumentParser
-from preprocess import get_words
+from preprocess import get_words, clean_text
 from shared import GeneralArguments, DatasetArguments
 from predict import predict
 from segment import extract_segment, word_start, word_end, SegmentationArguments, add_labels_to_words
@@ -31,6 +31,19 @@ class EvaluationArguments(InferenceArguments):
         }
     )
 
+    skip_missing: bool = field(
+        default=False,
+        metadata={
+            'help': 'Whether to skip checking for missing segments. If False, predictions will be made.'
+        }
+    )
+    skip_incorrect: bool = field(
+        default=False,
+        metadata={
+            'help': 'Whether to skip checking for incorrect segments. If False, classifications will be made on existing segments.'
+        }
+    )
+
 
 def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
     """Attach sponsor segments to closest prediction"""
@@ -46,7 +59,7 @@ def attach_predictions_to_sponsor_segments(predictions, sponsor_segments):
             prediction['best_overlap'] = j
             prediction['best_sponsorship'] = sponsor_segment
 
-    # return sponsor_segments
+    return sponsor_segments
 
 
 def calculate_metrics(labelled_words, predictions):
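
Note: the function now returns `sponsor_segments` instead of ending on a commented-out hint, so `main` can reassign the annotated list. A toy sketch of the attachment idea, assuming `utils.jaccard` computes an interval Jaccard index (the real implementation works over word boundaries):

```python
def jaccard(x1, x2, y1, y2):
    # Assumed behaviour of utils.jaccard: overlap/union of [x1, x2] and [y1, y2]
    intersection = max(0, min(x2, y2) - max(x1, y1))
    union = max(x2, y2) - min(x1, y1)
    return intersection / union if union > 0 else 0


def attach(predictions, sponsor_segments):
    # Each prediction keeps the sponsor segment it overlaps most, or None
    for prediction in predictions:
        prediction['best_overlap'] = 0
        prediction['best_sponsorship'] = None
        for sponsor_segment in sponsor_segments:
            j = jaccard(prediction['start'], prediction['end'],
                        sponsor_segment['start'], sponsor_segment['end'])
            if j > prediction['best_overlap']:
                prediction['best_overlap'] = j
                prediction['best_sponsorship'] = sponsor_segment
    return sponsor_segments
```
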
@@ -130,6 +143,10 @@ def main():
 
     evaluation_args, dataset_args, segmentation_args, general_args = hf_parser.parse_args_into_dataclasses()
 
+    if evaluation_args.skip_missing and evaluation_args.skip_incorrect:
+        logger.error('ERROR: Nothing to do')
+        return
+
     # Load labelled data:
     final_path = os.path.join(
         dataset_args.data_dir, dataset_args.processed_file)
@@ -158,14 +175,22 @@ def main():
     if evaluation_args.max_videos is not None:
         video_ids = video_ids[:evaluation_args.max_videos]
 
-    # TODO option to choose categories
+    out_metrics = []
 
-    total_accuracy = 0
-    total_precision = 0
-    total_recall = 0
-    total_fscore = 0
+    all_metrics = {}
+    if not evaluation_args.skip_missing:
+        all_metrics['total_prediction_accuracy'] = 0
+        all_metrics['total_prediction_precision'] = 0
+        all_metrics['total_prediction_recall'] = 0
+        all_metrics['total_prediction_fscore'] = 0
 
-    out_metrics = []
+    if not evaluation_args.skip_incorrect:
+        all_metrics['classifier_segment_correct'] = 0
+        all_metrics['classifier_segment_count'] = 0
+
+    metric_count = 0
+
+    postfix_info = {}
 
     try:
         with tqdm(video_ids) as progress:
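
Note: the flat `total_*` counters become an `all_metrics` dict whose keys exist only for the passes that actually run, and `postfix_info` is what gets pushed to the progress bar. The running-mean display pattern in isolation:

```python
from tqdm import tqdm

all_metrics = {'total_prediction_accuracy': 0}
metric_count = 0
with tqdm([0.50, 0.75, 1.00]) as progress:  # stand-in per-video accuracies
    for accuracy in progress:
        metric_count += 1
        all_metrics['total_prediction_accuracy'] += accuracy
        # The bar shows the mean over all videos processed so far
        progress.set_postfix(
            {'accuracy': all_metrics['total_prediction_accuracy'] / metric_count})
```
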
@@ -176,53 +201,77 @@ def main():
                 if not words:
                     continue
 
-                # Make predictions
-                predictions = predict(video_id, model, tokenizer, segmentation_args,
-                                      classifier=classifier,
-                                      min_probability=evaluation_args.min_probability)
-
                 # Get labels
                 sponsor_segments = final_data.get(video_id)
-                if sponsor_segments:
-                    labelled_words = add_labels_to_words(
-                        words, sponsor_segments)
-                    met = calculate_metrics(labelled_words, predictions)
-                    met['video_id'] = video_id
 
-                    out_metrics.append(met)
+                # Reset previous
+                missed_segments = []
+                incorrect_segments = []
 
-                    total_accuracy += met['accuracy']
-                    total_precision += met['precision']
-                    total_recall += met['recall']
-                    total_fscore += met['f-score']
+                current_metrics = {
+                    'video_id': video_id
+                }
+                metric_count += 1
 
-                    progress.set_postfix({
-                        'accuracy': total_accuracy/len(out_metrics),
-                        'precision': total_precision/len(out_metrics),
-                        'recall': total_recall/len(out_metrics),
-                        'f-score': total_fscore/len(out_metrics)
-                    })
+                if not evaluation_args.skip_missing:  # Make predictions
+                    predictions = predict(video_id, model, tokenizer, segmentation_args,
+                                          classifier=classifier,
+                                          min_probability=evaluation_args.min_probability)
 
-                    attach_predictions_to_sponsor_segments(
-                        predictions, sponsor_segments)
+                    if sponsor_segments:
+                        labelled_words = add_labels_to_words(
+                            words, sponsor_segments)
 
-                    # Identify possible issues:
-                    missed_segments = [
-                        prediction for prediction in predictions if prediction['best_sponsorship'] is None]
+                        current_metrics.update(
+                            calculate_metrics(labelled_words, predictions))
 
-                    # Now, check for incorrect segments using the classifier
-                    incorrect_segments = []
+                        all_metrics['total_prediction_accuracy'] += current_metrics['accuracy']
+                        all_metrics['total_prediction_precision'] += current_metrics['precision']
+                        all_metrics['total_prediction_recall'] += current_metrics['recall']
+                        all_metrics['total_prediction_fscore'] += current_metrics['f-score']
+
+                        # Just for display purposes
+                        postfix_info.update({
+                            'accuracy': all_metrics['total_prediction_accuracy']/metric_count,
+                            'precision': all_metrics['total_prediction_precision']/metric_count,
+                            'recall': all_metrics['total_prediction_recall']/metric_count,
+                            'f-score': all_metrics['total_prediction_fscore']/metric_count,
+                        })
+
+                        sponsor_segments = attach_predictions_to_sponsor_segments(
+                            predictions, sponsor_segments)
+
+                        # Identify possible issues:
+                        for prediction in predictions:
+                            if prediction['best_sponsorship'] is not None:
+                                continue
+
+                            prediction_words = prediction.pop('words', [])
+
+                            # Attach original text to missed segments
+                            prediction['text'] = ' '.join(
+                                x['text'] for x in prediction_words)
+                            missed_segments.append(prediction)
+
+                    else:
+                        # Not in database (all segments missed)
+                        missed_segments = predictions
+
+                if not evaluation_args.skip_incorrect and sponsor_segments:
+                    # Check for incorrect segments using the classifier
 
                     segments_to_check = []
                     texts = []  # Texts to send through tokenizer
                     for sponsor_segment in sponsor_segments:
                         segment_words = extract_segment(
                             words, sponsor_segment['start'], sponsor_segment['end'])
-                        sponsor_segment['text'] = ' '.join(x['text'] for x in segment_words)
-                        sponsor_segment['cleaned_text'] = ' '.join(x['cleaned'] for x in segment_words)
+                        sponsor_segment['text'] = ' '.join(
+                            x['text'] for x in segment_words)
 
-                        duration = sponsor_segment['end'] - sponsor_segment['start']
-                        wps = (len(segment_words) / duration) if duration > 0 else 0
+                        duration = sponsor_segment['end'] - \
+                            sponsor_segment['start']
+                        wps = (len(segment_words) /
+                               duration) if duration > 0 else 0
                         if wps < 1.5:
                             continue
 
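
Note: missing-segment detection now runs only when `skip_missing` is unset, and missed predictions are collected in an explicit loop so the word text can be attached in the same pass. The words-per-second guard is unchanged in intent: sparse-speech segments are skipped before classification. A worked instance of that guard:

```python
segment_words = ['w'] * 30   # 30 transcript words in the segment
start, end = 10.0, 40.0      # a 30-second segment
duration = end - start
wps = (len(segment_words) / duration) if duration > 0 else 0
assert wps == 1.0 < 1.5      # too sparse: skipped before classification
```
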
@@ -231,18 +280,24 @@ def main():
                         if sponsor_segment['locked']:
                             continue
 
+                        sponsor_segment['cleaned_text'] = clean_text(
+                            sponsor_segment['text'])
                         texts.append(sponsor_segment['cleaned_text'])
                         segments_to_check.append(sponsor_segment)
 
-                    if segments_to_check:  # Segments to check
+                    if segments_to_check:  # Some segments to check
 
                         segments_scores = classifier(texts)
 
+                        num_correct = 0
                         for segment, scores in zip(segments_to_check, segments_scores):
+                            all_metrics['classifier_segment_count'] += 1
+
                             prediction = max(scores, key=lambda x: x['score'])
                             predicted_category = prediction['label'].lower()
 
                             if predicted_category == segment['category']:
+                                num_correct += 1
                                 continue  # Ignore correct segments
 
                             segment.update({
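
Note: the loop above assumes each element of `segments_scores` is a list of `{'label', 'score'}` dicts, one per category, which is the usual Hugging Face text-classification pipeline output when all scores are returned. Illustrative values, not real model output:

```python
scores = [{'label': 'SPONSOR', 'score': 0.91},
          {'label': 'SELFPROMO', 'score': 0.07},
          {'label': 'INTERACTION', 'score': 0.02}]
prediction = max(scores, key=lambda x: x['score'])
predicted_category = prediction['label'].lower()
assert predicted_category == 'sponsor'
```
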
@@ -252,18 +307,19 @@ def main():
 
                             incorrect_segments.append(segment)
 
-                else:
-                    # logger.warning(f'No labels found for {video_id}')
-                    # Not in database (all segments missed)
-                    missed_segments = predictions
-                    incorrect_segments = []
+                        current_metrics['num_segments'] = len(
+                            segments_to_check)
+                        current_metrics['classified_correct'] = num_correct
+
+                        all_metrics['classifier_segment_correct'] += num_correct
+
+                        postfix_info['classifier_accuracy'] = all_metrics['classifier_segment_correct'] / \
+                            all_metrics['classifier_segment_count']
+
+                out_metrics.append(current_metrics)
+                progress.set_postfix(postfix_info)
 
                 if missed_segments or incorrect_segments:
-                    for z in missed_segments:
-                        # Attach original text to missed segments
-                        # (Already added to incorrect segments)
-                        z['text'] = ' '.join(x['text']
-                                             for x in z.pop('words', []))
 
                     if evaluation_args.output_as_json:
                         to_print = {'video_id': video_id}
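
Note: `classifier_accuracy` in the progress bar is a running ratio over every segment scored so far, across videos, not a per-video number. For example:

```python
all_metrics = {'classifier_segment_correct': 18,  # illustrative running totals
               'classifier_segment_count': 20}
classifier_accuracy = (all_metrics['classifier_segment_correct']
                       / all_metrics['classifier_segment_count'])
assert classifier_accuracy == 0.9
```
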
@@ -274,23 +330,25 @@ def main():
                     if incorrect_segments:
                         to_print['incorrect'] = incorrect_segments
 
-                    print(json.dumps(to_print))
+                    safe_print(json.dumps(to_print))
+
                 else:
-                    print(
+                    safe_print(
                         f'Issues identified for {video_id} (#{video_index})')
                     # Potentially missed segments (model predicted, but not in database)
                     if missed_segments:
-                        print(' - Missed segments:')
+                        safe_print(' - Missed segments:')
                         segments_to_submit = []
                         for i, missed_segment in enumerate(missed_segments, start=1):
-                            print(f'\t#{i}:', seconds_to_time(
+                            safe_print(f'\t#{i}:', seconds_to_time(
                                 missed_segment['start']), '-->', seconds_to_time(missed_segment['end']))
-                            print('\t\tText: "', missed_segment['text'], '"', sep='')
-                            print('\t\tCategory:',
-                                  missed_segment.get('category'))
+                            safe_print('\t\tText: "',
+                                       missed_segment['text'], '"', sep='')
+                            safe_print('\t\tCategory:',
+                                       missed_segment.get('category'))
                             if 'probability' in missed_segment:
-                                print('\t\tProbability:',
-                                      missed_segment['probability'])
+                                safe_print('\t\tProbability:',
+                                           missed_segment['probability'])
 
                             segments_to_submit.append({
                                 'segment': [missed_segment['start'], missed_segment['end']],
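
Note: the submit URL percent-encodes the proposed segments as JSON into a `#segments=` fragment (the format SponsorBlock uses to pre-fill submissions). A minimal sketch; the fields of the submitted dict other than `'segment'` fall between the hunks and are not shown in the diff, so only that key appears here:

```python
import json
from urllib.parse import quote

segments_to_submit = [{'segment': [12.3, 45.6]}]  # other keys elided in the diff
json_data = quote(json.dumps(segments_to_submit))
url = f'https://www.youtube.com/watch?v=VIDEO_ID#segments={json_data}'
print(url)
```
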
@@ -299,33 +357,37 @@ def main():
                                 })
 
                         json_data = quote(json.dumps(segments_to_submit))
-                        print(
+                        safe_print(
                             f'\tSubmit: https://www.youtube.com/watch?v={video_id}#segments={json_data}')
 
                     # Incorrect segments (in database, but incorrectly classified)
                     if incorrect_segments:
-                        print(' - Incorrect segments:')
+                        safe_print(' - Incorrect segments:')
                         for i, incorrect_segment in enumerate(incorrect_segments, start=1):
-                            print(f'\t#{i}:', seconds_to_time(
+                            safe_print(f'\t#{i}:', seconds_to_time(
                                 incorrect_segment['start']), '-->', seconds_to_time(incorrect_segment['end']))
 
-                            print('\t\tText: "', incorrect_segment['text'], '"', sep='')
-                            print('\t\tUUID:', incorrect_segment['uuid'])
-                            print('\t\tVotes:', incorrect_segment['votes'])
-                            print('\t\tViews:', incorrect_segment['views'])
-                            print('\t\tLocked:',
-                                  incorrect_segment['locked'])
-
-                            print('\t\tCurrent Category:',
-                                  incorrect_segment['category'])
-                            print('\t\tPredicted Category:',
-                                  incorrect_segment['predicted'])
-                            print('\t\tProbabilities:')
+                            safe_print(
+                                '\t\tText: "', incorrect_segment['text'], '"', sep='')
+                            safe_print(
+                                '\t\tUUID:', incorrect_segment['uuid'])
+                            safe_print(
+                                '\t\tVotes:', incorrect_segment['votes'])
+                            safe_print(
+                                '\t\tViews:', incorrect_segment['views'])
+                            safe_print('\t\tLocked:',
+                                       incorrect_segment['locked'])
+
+                            safe_print('\t\tCurrent Category:',
+                                       incorrect_segment['category'])
+                            safe_print('\t\tPredicted Category:',
+                                       incorrect_segment['predicted'])
+                            safe_print('\t\tProbabilities:')
                             for item in incorrect_segment['scores']:
-                                print(
+                                safe_print(
                                     f"\t\t\t{item['label']}: {item['score']}")
 
-                            print()
+                            safe_print()
 
     except KeyboardInterrupt:
         pass
 