Commit 83dc695 · Joshua Lochner committed · 1 Parent(s): 4b4c9f0

Fix `num_tokens` key in words

Files changed (1):
  1. src/segment.py (+14, -5)
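
In short: the per-word token counts returned by the tokenizer are now stored on each word dict under the 'num_tokens' key, so the second-pass loop reads the count of the word it is actually handling instead of reusing the num_tokens variable left over from the first pass (and instead of indexing num_tokens_list by position in the full word list when trimming). A minimal sketch of how those counts are produced and attached, assuming a Hugging Face fast tokenizer; the 't5-small' checkpoint and sample words are placeholders, not the project's configuration:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('t5-small')

# Hypothetical cleaned words, standing in for the w['cleaned'] values
cleaned_words_list = ['hello', 'world', 'sponsored']
words = [{'cleaned': w} for w in cleaned_words_list]

# Same call as in segment.py: one token count per cleaned word
num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False,
                            truncation=True, return_attention_mask=False,
                            return_length=True).length

for word, num_tokens in zip(words, num_tokens_list):
    word['num_tokens'] = num_tokens  # the key this commit attaches to each word
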
src/segment.py CHANGED

@@ -58,8 +58,12 @@ def generate_segments(words, tokenizer, segmentation_args):
         cleaned_words_list.append(w['cleaned'])
 
     # Get lengths of tokenized words
-    num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False, truncation=True, return_attention_mask=False, return_length=True).length
+    num_tokens_list = tokenizer(cleaned_words_list, add_special_tokens=False,
+                                truncation=True, return_attention_mask=False, return_length=True).length
+
     for index, (word, num_tokens) in enumerate(zip(words, num_tokens_list)):
+        word['num_tokens'] = num_tokens
+
         # Add new segment
         if index == 0 or word_start(words[index]) - word_end(words[index-1]) >= segmentation_args.pause_threshold:
             first_pass_segments.append([word])
@@ -80,7 +84,7 @@ def generate_segments(words, tokenizer, segmentation_args):
 
         for word in segment:
             new_seg = current_segment_num_tokens + \
-                num_tokens >= max_q_size
+                word['num_tokens'] >= max_q_size
             if new_seg:
                 # Adding this token would make it have too many tokens
                 # We save this batch and create new
@@ -88,7 +92,7 @@ def generate_segments(words, tokenizer, segmentation_args):
 
             # Add tokens to current segment
             current_segment.append(word)
-            current_segment_num_tokens += num_tokens
+            current_segment_num_tokens += word['num_tokens']
 
             if not new_seg:
                 continue
@@ -96,7 +100,7 @@ def generate_segments(words, tokenizer, segmentation_args):
             # Just created a new segment, so we remove until we only have buffer_size tokens
             last_index = 0
             while current_segment_num_tokens > buffer_size and current_segment:
-                current_segment_num_tokens -= num_tokens_list[last_index]
+                current_segment_num_tokens -= current_segment[last_index]['num_tokens']
                 last_index += 1
 
             current_segment = current_segment[last_index:]
@@ -104,6 +108,11 @@ def generate_segments(words, tokenizer, segmentation_args):
         if current_segment:  # Add remaining segment
             second_pass_segments.append(current_segment)
 
+    # Cleaning up, delete 'num_tokens' from each word
+    # for segment in second_pass_segments:
+    for word in words:
+        word.pop('num_tokens', None)
+
     return second_pass_segments
 
 
@@ -111,7 +120,7 @@ def extract_segment(words, start, end, map_function=None):
     """Extracts all words with time in [start, end]"""
 
     a = binary_search_below(words, 0, len(words) - 1, start)
-    b = min(binary_search_above(words, 0, len(words) - 1, end) + 1 , len(words))
+    b = min(binary_search_above(words, 0, len(words) - 1, end) + 1, len(words))
 
     to_transform = map_function is not None and callable(map_function)
 
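
With the counts attached to each word, the second pass can be read in isolation. Below is a minimal, self-contained sketch of that logic under the same assumption the diff establishes (each word dict carries 'num_tokens'); the function name split_into_segments and the max_q_size/buffer_size values are illustrative, not the project's actual names or defaults:

def split_into_segments(first_pass_segments, max_q_size=8, buffer_size=3):
    """Split word groups so each has at most max_q_size tokens, keeping
    roughly buffer_size tokens of overlap between consecutive splits."""
    second_pass_segments = []
    for segment in first_pass_segments:
        current_segment = []
        current_segment_num_tokens = 0
        for word in segment:
            # Would adding this word push the segment past max_q_size?
            new_seg = current_segment_num_tokens + word['num_tokens'] >= max_q_size
            if new_seg:
                # Save the full batch before starting a new one
                second_pass_segments.append(current_segment)

            current_segment.append(word)
            current_segment_num_tokens += word['num_tokens']

            if not new_seg:
                continue

            # Trim from the front until only ~buffer_size tokens remain,
            # subtracting each dropped word's own count (the fixed line)
            last_index = 0
            while current_segment_num_tokens > buffer_size and current_segment:
                current_segment_num_tokens -= current_segment[last_index]['num_tokens']
                last_index += 1
            current_segment = current_segment[last_index:]

        if current_segment:  # Add remaining segment
            second_pass_segments.append(current_segment)
    return second_pass_segments


# Hypothetical word dicts with precomputed token counts
words = [{'text': t, 'num_tokens': n} for t, n in
         [('hello', 1), ('world', 1), ('sponsored', 3), ('by', 1),
          ('example', 2), ('dot', 1), ('com', 1)]]
for seg in split_into_segments([words]):
    print([w['text'] for w in seg])
# -> ['hello', 'world', 'sponsored', 'by']
# -> ['by', 'example', 'dot', 'com']

Because the trimming loop now subtracts the count of the word actually being dropped, consecutive output segments keep roughly buffer_size tokens of overlapping context (here, 'by' appears at the end of the first segment and the start of the second), which is the behaviour the corrected subtraction preserves.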