nithinraok commited on
Commit
329912e
·
verified ·
1 Parent(s): 3ae07da

Create nemo_align.py

Browse files
Files changed (1) hide show
  1. nemo_align.py +522 -0
nemo_align.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from nemo.collections.asr.models import EncDecHybridRNNTCTCModel
2
+ from dataclasses import dataclass, field
3
+ from typing import List, Union
4
+ import torch
5
+ from nemo.utils import logging
6
+ from pathlib import Path
7
+ from viterbi_decoding import viterbi_decoding
8
+
9
+ BLANK_TOKEN = "<b>"
10
+
11
+ SPACE_TOKEN = "<space>"
12
+
13
+ V_NEGATIVE_NUM = -3.4e38
14
+
15
+
16
+ @dataclass
17
+ class Token:
18
+ text: str = None
19
+ text_cased: str = None
20
+ s_start: int = None
21
+ s_end: int = None
22
+ t_start: float = None
23
+ t_end: float = None
24
+
25
+
26
+ @dataclass
27
+ class Word:
28
+ text: str = None
29
+ s_start: int = None
30
+ s_end: int = None
31
+ t_start: float = None
32
+ t_end: float = None
33
+ tokens: List[Token] = field(default_factory=list)
34
+
35
+
36
+ @dataclass
37
+ class Segment:
38
+ text: str = None
39
+ s_start: int = None
40
+ s_end: int = None
41
+ t_start: float = None
42
+ t_end: float = None
43
+ words_and_tokens: List[Union[Word, Token]] = field(default_factory=list)
44
+
45
+
46
+ @dataclass
47
+ class Utterance:
48
+ token_ids_with_blanks: List[int] = field(default_factory=list)
49
+ segments_and_tokens: List[Union[Segment, Token]] = field(default_factory=list)
50
+ text: str = None
51
+ pred_text: str = None
52
+ audio_filepath: str = None
53
+ utt_id: str = None
54
+ saved_output_files: dict = field(default_factory=dict)
55
+
56
+ def is_sub_or_superscript_pair(ref_text, text):
57
+ """returns True if ref_text is a subscript or superscript version of text"""
58
+ sub_or_superscript_to_num = {
59
+ "⁰": "0",
60
+ "¹": "1",
61
+ "²": "2",
62
+ "³": "3",
63
+ "⁴": "4",
64
+ "⁵": "5",
65
+ "⁶": "6",
66
+ "⁷": "7",
67
+ "⁸": "8",
68
+ "⁹": "9",
69
+ "₀": "0",
70
+ "₁": "1",
71
+ "₂": "2",
72
+ "₃": "3",
73
+ "₄": "4",
74
+ "₅": "5",
75
+ "₆": "6",
76
+ "₇": "7",
77
+ "₈": "8",
78
+ "₉": "9",
79
+ }
80
+
81
+ if text in sub_or_superscript_to_num:
82
+ if sub_or_superscript_to_num[text] == ref_text:
83
+ return True
84
+ return False
85
+
86
+ def restore_token_case(word, word_tokens):
87
+
88
+ # remove repeated "▁" and "_" from word as that is what the tokenizer will do
89
+ while "▁▁" in word:
90
+ word = word.replace("▁▁", "▁")
91
+
92
+ while "__" in word:
93
+ word = word.replace("__", "_")
94
+
95
+ word_tokens_cased = []
96
+ word_char_pointer = 0
97
+
98
+ for token in word_tokens:
99
+ token_cased = ""
100
+
101
+ for token_char in token:
102
+ if token_char == word[word_char_pointer]:
103
+ token_cased += token_char
104
+ word_char_pointer += 1
105
+
106
+ else:
107
+ if token_char.upper() == word[word_char_pointer] or is_sub_or_superscript_pair(
108
+ token_char, word[word_char_pointer]
109
+ ):
110
+ token_cased += token_char.upper()
111
+ word_char_pointer += 1
112
+ else:
113
+ if token_char == "▁" or token_char == "_":
114
+ if word[word_char_pointer] == "▁" or word[word_char_pointer] == "_":
115
+ token_cased += token_char
116
+ word_char_pointer += 1
117
+ elif word_char_pointer == 0:
118
+ token_cased += token_char
119
+
120
+ else:
121
+ raise RuntimeError(
122
+ f"Unexpected error - failed to recover capitalization of tokens for word {word}"
123
+ )
124
+
125
+ word_tokens_cased.append(token_cased)
126
+
127
+ return word_tokens_cased
128
+
129
+ def get_utt_obj(
130
+ text, model, separator, T, audio_filepath, utt_id,
131
+ ):
132
+ """
133
+ Function to create an Utterance object and add all necessary information to it except
134
+ for timings of the segments / words / tokens according to the alignment - that will
135
+ be done later in a different function, after the alignment is done.
136
+
137
+ The Utterance object has a list segments_and_tokens which contains Segment objects and
138
+ Token objects (for blank tokens in between segments).
139
+ Within the Segment objects, there is a list words_and_tokens which contains Word objects and
140
+ Token objects (for blank tokens in between words).
141
+ Within the Word objects, there is a list tokens tokens which contains Token objects for
142
+ blank and non-blank tokens.
143
+ We will be building up these lists in this function. This data structure will then be useful for
144
+ generating the various output files that we wish to save.
145
+ """
146
+
147
+ if not separator: # if separator is not defined - treat the whole text as one segment
148
+ segments = [text]
149
+ else:
150
+ segments = text.split(separator)
151
+
152
+ # remove any spaces at start and end of segments
153
+ segments = [seg.strip() for seg in segments]
154
+ # remove any empty segments
155
+ segments = [seg for seg in segments if len(seg) > 0]
156
+
157
+ utt = Utterance(text=text, audio_filepath=audio_filepath, utt_id=utt_id,)
158
+
159
+ # build up lists: token_ids_with_blanks, segments_and_tokens.
160
+ # The code for these is different depending on whether we use char-based tokens or not
161
+ if hasattr(model, 'tokenizer'):
162
+ if hasattr(model, 'blank_id'):
163
+ BLANK_ID = model.blank_id
164
+ else:
165
+ BLANK_ID = len(model.tokenizer.vocab) # TODO: check
166
+
167
+ utt.token_ids_with_blanks = [BLANK_ID]
168
+
169
+ # check for text being 0 length
170
+ if len(text) == 0:
171
+ return utt
172
+
173
+ # check for # tokens + token repetitions being > T
174
+ all_tokens = model.tokenizer.text_to_ids(text)
175
+ n_token_repetitions = 0
176
+ for i_tok in range(1, len(all_tokens)):
177
+ if all_tokens[i_tok] == all_tokens[i_tok - 1]:
178
+ n_token_repetitions += 1
179
+
180
+ if len(all_tokens) + n_token_repetitions > T:
181
+ logging.info(
182
+ f"Utterance {utt_id} has too many tokens compared to the audio file duration."
183
+ " Will not generate output alignment files for this utterance."
184
+ )
185
+ return utt
186
+
187
+ # build up data structures containing segments/words/tokens
188
+ utt.segments_and_tokens.append(Token(text=BLANK_TOKEN, text_cased=BLANK_TOKEN, s_start=0, s_end=0,))
189
+
190
+ segment_s_pointer = 1 # first segment will start at s=1 because s=0 is a blank
191
+ word_s_pointer = 1 # first word will start at s=1 because s=0 is a blank
192
+
193
+ for segment in segments:
194
+ # add the segment to segment_info and increment the segment_s_pointer
195
+ segment_tokens = model.tokenizer.text_to_tokens(segment)
196
+ utt.segments_and_tokens.append(
197
+ Segment(
198
+ text=segment,
199
+ s_start=segment_s_pointer,
200
+ # segment_tokens do not contain blanks => need to muliply by 2
201
+ # s_end needs to be the index of the final token (including blanks) of the current segment:
202
+ # segment_s_pointer + len(segment_tokens) * 2 is the index of the first token of the next segment =>
203
+ # => need to subtract 2
204
+ s_end=segment_s_pointer + len(segment_tokens) * 2 - 2,
205
+ )
206
+ )
207
+ segment_s_pointer += (
208
+ len(segment_tokens) * 2
209
+ ) # multiply by 2 to account for blanks (which are not present in segment_tokens)
210
+
211
+ words = segment.split(" ") # we define words to be space-separated sub-strings
212
+ for word_i, word in enumerate(words):
213
+
214
+ word_tokens = model.tokenizer.text_to_tokens(word)
215
+ word_token_ids = model.tokenizer.text_to_ids(word)
216
+ word_tokens_cased = restore_token_case(word, word_tokens)
217
+
218
+ # add the word to word_info and increment the word_s_pointer
219
+ utt.segments_and_tokens[-1].words_and_tokens.append(
220
+ # word_tokens do not contain blanks => need to muliply by 2
221
+ # s_end needs to be the index of the final token (including blanks) of the current word:
222
+ # word_s_pointer + len(word_tokens) * 2 is the index of the first token of the next word =>
223
+ # => need to subtract 2
224
+ Word(text=word, s_start=word_s_pointer, s_end=word_s_pointer + len(word_tokens) * 2 - 2)
225
+ )
226
+ word_s_pointer += (
227
+ len(word_tokens) * 2
228
+ ) # multiply by 2 to account for blanks (which are not present in word_tokens)
229
+
230
+ for token_i, (token, token_id, token_cased) in enumerate(
231
+ zip(word_tokens, word_token_ids, word_tokens_cased)
232
+ ):
233
+ # add the text tokens and the blanks in between them
234
+ # to our token-based variables
235
+ utt.token_ids_with_blanks.extend([token_id, BLANK_ID])
236
+ # adding Token object for non-blank token
237
+ utt.segments_and_tokens[-1].words_and_tokens[-1].tokens.append(
238
+ Token(
239
+ text=token,
240
+ text_cased=token_cased,
241
+ # utt.token_ids_with_blanks has the form [...., <this non-blank token>, <blank>] =>
242
+ # => if do len(utt.token_ids_with_blanks) - 1 you get the index of the final <blank>
243
+ # => we want to do len(utt.token_ids_with_blanks) - 2 to get the index of <this non-blank token>
244
+ s_start=len(utt.token_ids_with_blanks) - 2,
245
+ # s_end is same as s_start since the token only occupies one element in the list
246
+ s_end=len(utt.token_ids_with_blanks) - 2,
247
+ )
248
+ )
249
+
250
+ # adding Token object for blank tokens in between the tokens of the word
251
+ # (ie do not add another blank if you have reached the end)
252
+ if token_i < len(word_tokens) - 1:
253
+ utt.segments_and_tokens[-1].words_and_tokens[-1].tokens.append(
254
+ Token(
255
+ text=BLANK_TOKEN,
256
+ text_cased=BLANK_TOKEN,
257
+ # utt.token_ids_with_blanks has the form [...., <this blank token>] =>
258
+ # => if do len(utt.token_ids_with_blanks) -1 you get the index of this <blank>
259
+ s_start=len(utt.token_ids_with_blanks) - 1,
260
+ # s_end is same as s_start since the token only occupies one element in the list
261
+ s_end=len(utt.token_ids_with_blanks) - 1,
262
+ )
263
+ )
264
+
265
+ # add a Token object for blanks in between words in this segment
266
+ # (but only *in between* - do not add the token if it is after the final word)
267
+ if word_i < len(words) - 1:
268
+ utt.segments_and_tokens[-1].words_and_tokens.append(
269
+ Token(
270
+ text=BLANK_TOKEN,
271
+ text_cased=BLANK_TOKEN,
272
+ # utt.token_ids_with_blanks has the form [...., <this blank token>] =>
273
+ # => if do len(utt.token_ids_with_blanks) -1 you get the index of this <blank>
274
+ s_start=len(utt.token_ids_with_blanks) - 1,
275
+ # s_end is same as s_start since the token only occupies one element in the list
276
+ s_end=len(utt.token_ids_with_blanks) - 1,
277
+ )
278
+ )
279
+
280
+ # add the blank token in between segments/after the final segment
281
+ utt.segments_and_tokens.append(
282
+ Token(
283
+ text=BLANK_TOKEN,
284
+ text_cased=BLANK_TOKEN,
285
+ # utt.token_ids_with_blanks has the form [...., <this blank token>] =>
286
+ # => if do len(utt.token_ids_with_blanks) -1 you get the index of this <blank>
287
+ s_start=len(utt.token_ids_with_blanks) - 1,
288
+ # s_end is same as s_start since the token only occupies one element in the list
289
+ s_end=len(utt.token_ids_with_blanks) - 1,
290
+ )
291
+ )
292
+
293
+ return utt
294
+
295
+ def _get_utt_id(audio_filepath, audio_filepath_parts_in_utt_id):
296
+ fp_parts = Path(audio_filepath).parts[-audio_filepath_parts_in_utt_id:]
297
+ utt_id = Path("_".join(fp_parts)).stem
298
+ utt_id = utt_id.replace(" ", "-") # replace any spaces in the filepath with dashes
299
+ return utt_id
300
+
301
+ def add_t_start_end_to_utt_obj(utt_obj, alignment_utt, output_timestep_duration):
302
+ """
303
+ Function to add t_start and t_end (representing time in seconds) to the Utterance object utt_obj.
304
+ Args:
305
+ utt_obj: Utterance object to which we will add t_start and t_end for its
306
+ constituent segments/words/tokens.
307
+ alignment_utt: a list of ints indicating which token does the alignment pass through at each
308
+ timestep (will take the form [0, 0, 1, 1, ..., <num of tokens including blanks in uterance>]).
309
+ output_timestep_duration: a float indicating the duration of a single output timestep from
310
+ the ASR Model.
311
+
312
+ Returns:
313
+ utt_obj: updated Utterance object.
314
+ """
315
+
316
+ # General idea for the algorithm of how we add t_start and t_end
317
+ # the timestep where a token s starts is the location of the first appearance of s_start in alignment_utt
318
+ # the timestep where a token s ends is the location of the final appearance of s_end in alignment_utt
319
+ # We will make dictionaries num_to_first_alignment_appearance and
320
+ # num_to_last_appearance and use that to update all of
321
+ # the t_start and t_end values in utt_obj.
322
+ # We will put t_start = t_end = -1 for tokens that are skipped (should only be blanks)
323
+
324
+ num_to_first_alignment_appearance = dict()
325
+ num_to_last_alignment_appearance = dict()
326
+
327
+ prev_s = -1 # use prev_s to keep track of when the s changes
328
+ for t, s in enumerate(alignment_utt):
329
+ if s > prev_s:
330
+ num_to_first_alignment_appearance[s] = t
331
+
332
+ if prev_s >= 0: # dont record prev_s = -1
333
+ num_to_last_alignment_appearance[prev_s] = t - 1
334
+ prev_s = s
335
+ # add last appearance of the final s
336
+ num_to_last_alignment_appearance[prev_s] = len(alignment_utt) - 1
337
+
338
+ # update all the t_start and t_end in utt_obj
339
+ for segment_or_token in utt_obj.segments_and_tokens:
340
+ if type(segment_or_token) is Segment:
341
+ segment = segment_or_token
342
+ segment.t_start = num_to_first_alignment_appearance[segment.s_start] * output_timestep_duration
343
+ segment.t_end = (num_to_last_alignment_appearance[segment.s_end] + 1) * output_timestep_duration
344
+
345
+ for word_or_token in segment.words_and_tokens:
346
+ if type(word_or_token) is Word:
347
+ word = word_or_token
348
+ word.t_start = num_to_first_alignment_appearance[word.s_start] * output_timestep_duration
349
+ word.t_end = (num_to_last_alignment_appearance[word.s_end] + 1) * output_timestep_duration
350
+
351
+ for token in word.tokens:
352
+ if token.s_start in num_to_first_alignment_appearance:
353
+ token.t_start = num_to_first_alignment_appearance[token.s_start] * output_timestep_duration
354
+ else:
355
+ token.t_start = -1
356
+
357
+ if token.s_end in num_to_last_alignment_appearance:
358
+ token.t_end = (
359
+ num_to_last_alignment_appearance[token.s_end] + 1
360
+ ) * output_timestep_duration
361
+ else:
362
+ token.t_end = -1
363
+ else:
364
+ token = word_or_token
365
+ if token.s_start in num_to_first_alignment_appearance:
366
+ token.t_start = num_to_first_alignment_appearance[token.s_start] * output_timestep_duration
367
+ else:
368
+ token.t_start = -1
369
+
370
+ if token.s_end in num_to_last_alignment_appearance:
371
+ token.t_end = (num_to_last_alignment_appearance[token.s_end] + 1) * output_timestep_duration
372
+ else:
373
+ token.t_end = -1
374
+
375
+ else:
376
+ token = segment_or_token
377
+ if token.s_start in num_to_first_alignment_appearance:
378
+ token.t_start = num_to_first_alignment_appearance[token.s_start] * output_timestep_duration
379
+ else:
380
+ token.t_start = -1
381
+
382
+ if token.s_end in num_to_last_alignment_appearance:
383
+ token.t_end = (num_to_last_alignment_appearance[token.s_end] + 1) * output_timestep_duration
384
+ else:
385
+ token.t_end = -1
386
+
387
+ return utt_obj
388
+
389
+ def get_word_timings(
390
+ alignment_level, utt_obj,
391
+ ):
392
+ boundary_info_utt = []
393
+ for segment_or_token in utt_obj.segments_and_tokens:
394
+ if type(segment_or_token) is Segment:
395
+ segment = segment_or_token
396
+ for word_or_token in segment.words_and_tokens:
397
+ if type(word_or_token) is Word:
398
+ word = word_or_token
399
+ if alignment_level == "words":
400
+ boundary_info_utt.append(word)
401
+
402
+ word_timestamps=[]
403
+ for boundary_info_ in boundary_info_utt: # loop over every token/word/segment
404
+
405
+ # skip if t_start = t_end = negative number because we used it as a marker to skip some blank tokens
406
+ if not (boundary_info_.t_start < 0 or boundary_info_.t_end < 0):
407
+ text = boundary_info_.text
408
+ start_time = boundary_info_.t_start
409
+ end_time = boundary_info_.t_end
410
+
411
+ text = text.replace(" ", SPACE_TOKEN)
412
+ word_timestamps.append((text, start_time, end_time))
413
+
414
+ return word_timestamps
415
+
416
+ def get_start_end_for_segments(word_timestamps):
417
+ segment_timestamps=[]
418
+ word_list = []
419
+ beginning = None
420
+ for word, start, end in word_timestamps:
421
+ if beginning is None:
422
+ beginning = start
423
+ word = word.capitalize()
424
+ word_list.append(word)
425
+ if word.endswith('.') or word.endswith('?') or word.endswith('!'):
426
+ segment = ' '.join(word_list)
427
+ segment_timestamps.append((segment, beginning, end))
428
+ beginning = None
429
+ word_list = []
430
+
431
+
432
+ segment = ' '.join(word_list)
433
+ segment_timestamps.append((segment, beginning, end))
434
+
435
+ return segment_timestamps
436
+
437
+
438
+ def align_tdt_to_ctc_timestamps(tdt_txt, model, audio_filepath):
439
+ if isinstance(model, EncDecHybridRNNTCTCModel):
440
+ model.change_decoding_strategy(decoder_type="ctc")
441
+ else:
442
+ raise ValueError("Currently supporting hybrid models")
443
+
444
+ with torch.cuda.amp.autocast(enabled=True, dtype=torch.bfloat16):
445
+ with torch.inference_mode():
446
+ hypotheses = model.transcribe([audio_filepath], return_hypotheses=True, batch_size=1)
447
+
448
+ if type(hypotheses) == tuple and len(hypotheses) == 2:
449
+ hypotheses = hypotheses[0]
450
+
451
+ log_probs_list_batch = [hypotheses[0].y_sequence]
452
+ T_list_batch = [hypotheses[0].y_sequence.shape[0]]
453
+ ctc_pred_text = hypotheses[0].text if tdt_txt is not None else tdt_txt
454
+
455
+ utt_obj = get_utt_obj(
456
+ ctc_pred_text,
457
+ model,
458
+ None,
459
+ T_list_batch[0],
460
+ audio_filepath,
461
+ _get_utt_id(audio_filepath, 1),
462
+ )
463
+
464
+ utt_obj.pred_text = ctc_pred_text
465
+
466
+ y_list_batch = [utt_obj.token_ids_with_blanks]
467
+ U_list_batch = [len(utt_obj.token_ids_with_blanks)]
468
+
469
+ if hasattr(model, 'tokenizer'):
470
+ V = len(model.tokenizer.vocab) + 1
471
+ else:
472
+ V = len(model.decoder.vocabulary) + 1
473
+
474
+ # turn log_probs, y, T, U into dense tensors for fast computation during Viterbi decoding
475
+ T_max = max(T_list_batch)
476
+ U_max = max(U_list_batch)
477
+ # V = the number of tokens in the vocabulary + 1 for the blank token.
478
+ if hasattr(model, 'tokenizer'):
479
+ V = len(model.tokenizer.vocab) + 1
480
+ else:
481
+ V = len(model.decoder.vocabulary) + 1
482
+ T_batch = torch.tensor(T_list_batch)
483
+ U_batch = torch.tensor(U_list_batch)
484
+
485
+ # make log_probs_batch tensor of shape (B x T_max x V)
486
+ log_probs_batch = V_NEGATIVE_NUM * torch.ones((1, T_max, V))
487
+ for b, log_probs_utt in enumerate(log_probs_list_batch):
488
+ t = log_probs_utt.shape[0]
489
+ log_probs_batch[b, :t, :] = log_probs_utt
490
+
491
+ y_batch = V * torch.ones((1, U_max), dtype=torch.int64)
492
+ for b, y_utt in enumerate(y_list_batch):
493
+ U_utt = U_batch[b]
494
+ y_batch[b, :U_utt] = torch.tensor(y_utt)
495
+
496
+ model_downsample_factor = 8
497
+ output_timestep_duration = (
498
+ model.preprocessor.featurizer.hop_length * model_downsample_factor / model.cfg.preprocessor.sample_rate
499
+ )
500
+
501
+ alignments_batch = viterbi_decoding(log_probs_batch, y_batch, T_batch, U_batch, torch.device('cuda'))
502
+
503
+
504
+ utt_obj = add_t_start_end_to_utt_obj(utt_obj, alignments_batch[0], output_timestep_duration)
505
+
506
+ word_timestamps = get_word_timings("words", utt_obj=utt_obj)
507
+
508
+ segmet_timestamps = get_start_end_for_segments(word_timestamps)
509
+
510
+ return segmet_timestamps
511
+
512
+
513
+ # def main():
514
+ # # model = 'nvidia/parakeet-tdt_ctc-1.1b.nemo'
515
+ # # from nemo.collections.asr.models import ASRModel
516
+ # # asr_model = ASRModel.from_pretrained(model).to('cuda')
517
+ # # asr_model.eval()
518
+ # # Segments = align_tdt_to_ctc_timestamps(None, asr_model, 'processed_file.flac')
519
+
520
+
521
+ # if __name__ == '__main__':
522
+ # main()