Spaces: Joshua Lochner committed · 83dc695
Parent(s): 4b4c9f0

Fix `num_tokens` key in words

Files changed: src/segment.py (+14 -5)

src/segment.py CHANGED
@@ -58,8 +58,12 @@ def generate_segments(words, tokenizer, segmentation_args):
         cleaned_words_list.append(w['cleaned'])
 
     # Get lengths of tokenized words
-    num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False,
+    num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False,
+                                truncation=True, return_attention_mask=False, return_length=True).length
+
     for index, (word, num_tokens) in enumerate(zip(words, num_tokens_list)):
+        word['num_tokens'] = num_tokens
+
         # Add new segment
         if index == 0 or word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
             first_pass_segments.append([word])
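For context, a minimal runnable sketch of what the new batched call computes, assuming a Hugging Face tokenizer (the checkpoint below is a placeholder for illustration, not taken from this repo). return_length=True makes the returned BatchEncoding carry a per-input token count, exposed here as .length, so one batched call replaces tokenizing each word separately:

from transformers import AutoTokenizer

# Placeholder checkpoint for illustration only.
tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')

words = [{'cleaned': 'hello'}, {'cleaned': 'world'}, {'cleaned': 'tokenization'}]
cleaned_words_list = [w['cleaned'] for w in words]

# One batched call returns a token count per input string.
num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False,
                            truncation=True, return_attention_mask=False,
                            return_length=True).length

# Attach each count to its word dict, as the patch does.
for word, num_tokens in zip(words, num_tokens_list):
    word['num_tokens'] = num_tokens

print(num_tokens_list)  # e.g. [1, 1, 4], depending on the vocabulary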
@@ -80,7 +84,7 @@ def generate_segments(words, tokenizer, segmentation_args):
 
         for word in segment:
             new_seg = current_segment_num_tokens + \
-                num_tokens >= max_q_size
+                word['num_tokens'] >= max_q_size
             if new_seg:
                 # Adding this token would make it have too many tokens
                 # We save this batch and create new
@@ -88,7 +92,7 @@ def generate_segments(words, tokenizer, segmentation_args):
 
             # Add tokens to current segment
             current_segment.append(word)
-            current_segment_num_tokens += num_tokens
+            current_segment_num_tokens += word['num_tokens']
 
             if not new_seg:
                 continue
@@ -96,7 +100,7 @@ def generate_segments(words, tokenizer, segmentation_args):
             # Just created a new segment, so we remove until we only have buffer_size tokens
             last_index = 0
             while current_segment_num_tokens > buffer_size and current_segment:
-                current_segment_num_tokens -=
+                current_segment_num_tokens -= current_segment[last_index]['num_tokens']
                 last_index += 1
 
             current_segment = current_segment[last_index:]
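Taken together, the three hunks above patch a sliding token-budget window: words accumulate until adding one more would reach max_q_size tokens, the window is flushed, and the left side is trimmed until only about buffer_size tokens of overlap carry into the next window. A standalone sketch under those assumptions (variable names from the diff; the surrounding bookkeeping is simplified):

def second_pass(segment, max_q_size, buffer_size):
    """Split one first-pass segment into windows of at most max_q_size
    tokens, keeping ~buffer_size tokens of overlap between windows."""
    segments = []
    current_segment, current_segment_num_tokens = [], 0

    for word in segment:
        new_seg = current_segment_num_tokens + word['num_tokens'] >= max_q_size
        if new_seg:
            # Adding this word would overflow the budget: flush what we have.
            segments.append(list(current_segment))

        current_segment.append(word)
        current_segment_num_tokens += word['num_tokens']

        if not new_seg:
            continue

        # Drop words from the left until only ~buffer_size tokens remain,
        # so the next window starts with some overlapping context.
        last_index = 0
        while current_segment_num_tokens > buffer_size and last_index < len(current_segment):
            current_segment_num_tokens -= current_segment[last_index]['num_tokens']
            last_index += 1
        current_segment = current_segment[last_index:]

    if current_segment:  # Add remaining segment
        segments.append(current_segment)
    return segments

print(second_pass([{'num_tokens': 3}] * 10, max_q_size=8, buffer_size=3))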
@@ -104,6 +108,11 @@ def generate_segments(words, tokenizer, segmentation_args):
     if current_segment:  # Add remaining segment
         second_pass_segments.append(current_segment)
 
+    # Cleaning up, delete 'num_tokens' from each word
+    # for segment in second_pass_segments:
+    for word in words:
+        word.pop('num_tokens', None)
+
     return second_pass_segments
 
 
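A small detail in the cleanup above: pop is called with a default, so a word that never received the key (for example, if the earlier zip stopped short of the full words list) is a no-op rather than a KeyError:

word = {'cleaned': 'hello'}   # no 'num_tokens' key was attached
word.pop('num_tokens', None)  # returns None instead of raising KeyError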
@@ -111,7 +120,7 @@ def extract_segment(words, start, end, map_function=None):
     """Extracts all words with time in [start, end]"""
 
     a = binary_search_below(words, 0, len(words) - 1, start)
-    b = min(binary_search_above(words, 0, len(words) - 1, end) + 1
+    b = min(binary_search_above(words, 0, len(words) - 1, end) + 1, len(words))
 
     to_transform = map_function is not None and callable(map_function)
 
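The extract_segment change clamps the exclusive end index so the `+ 1` cannot push past the list. The binary_search_below/above helpers are defined elsewhere in the file and are not shown in this diff; a rough runnable analogue of the index math using the standard library's bisect module (assumed semantics, not the repo's implementation):

import bisect

def extract_range(words, start, end):
    """Rough analogue of extract_segment's index math (assumed semantics)."""
    starts = [w['start'] for w in words]
    a = bisect.bisect_left(starts, start)        # first word at/after start
    last = bisect.bisect_right(starts, end) - 1  # last word at/before end
    b = min(last + 1, len(words))                # the patch's clamp: exclusive end stays in bounds
    return words[a:b]

words = [{'start': 0.0}, {'start': 1.5}, {'start': 3.2}]
print(extract_range(words, 1.0, 99.0))  # an end past the last word still yields a valid slice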