Spaces:
Running
Running
Joshua Lochner
commited on
Commit
·
69fe24d
1
Parent(s):
c2ccf6d
Temporarily disable filtering of predictions using classifier
Browse files- src/predict.py +15 -10
src/predict.py
CHANGED
|
@@ -23,14 +23,17 @@ from dataclasses import dataclass, field
|
|
| 23 |
from shared import device
|
| 24 |
import logging
|
| 25 |
|
|
|
|
| 26 |
|
| 27 |
-
def seconds_to_time(seconds):
|
| 28 |
fractional = round(seconds % 1, 3)
|
| 29 |
fractional = '' if fractional == 0 else str(fractional)[1:]
|
| 30 |
h, remainder = divmod(abs(int(seconds)), 3600)
|
| 31 |
m, s = divmod(remainder, 60)
|
| 32 |
-
|
| 33 |
-
|
|
|
|
|
|
|
| 34 |
|
| 35 |
@dataclass
|
| 36 |
class TrainingOutputArguments:
|
|
@@ -136,7 +139,7 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
|
|
| 136 |
segmentation_args
|
| 137 |
)
|
| 138 |
|
| 139 |
-
predictions =
|
| 140 |
|
| 141 |
# Add words back to time_ranges
|
| 142 |
for prediction in predictions:
|
|
@@ -144,8 +147,9 @@ def predict(video_id, model, tokenizer, segmentation_args, words=None, classifie
|
|
| 144 |
prediction['words'] = extract_segment(
|
| 145 |
words, prediction['start'], prediction['end'])
|
| 146 |
|
| 147 |
-
|
| 148 |
-
|
|
|
|
| 149 |
|
| 150 |
return predictions
|
| 151 |
|
|
@@ -188,7 +192,7 @@ def predict_sponsor_matches(text, model, tokenizer):
|
|
| 188 |
return re_findall(SPONSOR_MATCH_RE, sponsorship_text)
|
| 189 |
|
| 190 |
|
| 191 |
-
def
|
| 192 |
predicted_time_ranges = []
|
| 193 |
|
| 194 |
# TODO pass to model simultaneously, not in for loop
|
|
@@ -234,10 +238,11 @@ def segments_to_prediction_times(segments, model, tokenizer):
|
|
| 234 |
end_time = range['end']
|
| 235 |
|
| 236 |
if prev_prediction is not None and range['category'] == prev_prediction['category'] and (
|
| 237 |
-
start_time <= prev_prediction['end'] <= end_time or
|
| 238 |
-
prev_prediction['end'] <= MERGE_TIME_WITHIN
|
| 239 |
):
|
| 240 |
-
# Ending time of last segment is in this segment or
|
|
|
|
| 241 |
final_predicted_time_ranges[-1]['end'] = end_time
|
| 242 |
|
| 243 |
else: # No overlap, is a new prediction
|
|
|
|
| 23 |
from shared import device
|
| 24 |
import logging
|
| 25 |
|
| 26 |
+
import re
|
| 27 |
|
| 28 |
+
def seconds_to_time(seconds, remove_leading_zeroes=False):
|
| 29 |
fractional = round(seconds % 1, 3)
|
| 30 |
fractional = '' if fractional == 0 else str(fractional)[1:]
|
| 31 |
h, remainder = divmod(abs(int(seconds)), 3600)
|
| 32 |
m, s = divmod(remainder, 60)
|
| 33 |
+
hms = f'{h:02}:{m:02}:{s:02}'
|
| 34 |
+
if remove_leading_zeroes:
|
| 35 |
+
hms = re.sub(r'^0(?:0:0?)?', '', hms)
|
| 36 |
+
return f"{'-' if seconds < 0 else ''}{hms}{fractional}"
|
| 37 |
|
| 38 |
@dataclass
|
| 39 |
class TrainingOutputArguments:
|
|
|
|
| 139 |
segmentation_args
|
| 140 |
)
|
| 141 |
|
| 142 |
+
predictions = segments_to_predictions(segments, model, tokenizer)
|
| 143 |
|
| 144 |
# Add words back to time_ranges
|
| 145 |
for prediction in predictions:
|
|
|
|
| 147 |
prediction['words'] = extract_segment(
|
| 148 |
words, prediction['start'], prediction['end'])
|
| 149 |
|
| 150 |
+
# TODO add back
|
| 151 |
+
# if classifier_args is not None:
|
| 152 |
+
# predictions = filter_predictions(predictions, classifier_args)
|
| 153 |
|
| 154 |
return predictions
|
| 155 |
|
|
|
|
| 192 |
return re_findall(SPONSOR_MATCH_RE, sponsorship_text)
|
| 193 |
|
| 194 |
|
| 195 |
+
def segments_to_predictions(segments, model, tokenizer):
|
| 196 |
predicted_time_ranges = []
|
| 197 |
|
| 198 |
# TODO pass to model simultaneously, not in for loop
|
|
|
|
| 238 |
end_time = range['end']
|
| 239 |
|
| 240 |
if prev_prediction is not None and range['category'] == prev_prediction['category'] and (
|
| 241 |
+
start_time <= prev_prediction['end'] <= end_time or \
|
| 242 |
+
start_time - prev_prediction['end'] <= MERGE_TIME_WITHIN
|
| 243 |
):
|
| 244 |
+
# Ending time of last segment is in this segment or within the merge threshold,
|
| 245 |
+
# so we extend last prediction range
|
| 246 |
final_predicted_time_ranges[-1]['end'] = end_time
|
| 247 |
|
| 248 |
else: # No overlap, is a new prediction
|