Refactor pad and merge timestamps into one function
Browse files
This also fixes a number of issues regarding when the timestamps
should be merged.
- src/segments.py +47 -0
- src/vad.py +6 -66
- tests/segments_test.py +48 -0
    	
        src/segments.py
    ADDED
    
    | @@ -0,0 +1,47 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            from typing import Any, Dict, List
         | 
| 2 | 
            +
             | 
| 3 | 
            +
            import copy
         | 
| 4 | 
            +
             | 
| 5 | 
            +
            def merge_timestamps(timestamps: List[Dict[str, Any]], merge_window: float = 5, max_merge_size: float = 30, padding_left: float = 1, padding_right: float = 1):
                """Merge nearby segments and pad the merged result in a single pass.

                Each entry of ``timestamps`` is a dict with at least ``'start'`` and
                ``'end'`` keys; entries are processed in the order given. The input
                list and its dicts are never modified: every output segment starts
                as a deep copy of the first input segment it covers, so extra keys
                of that first segment are carried over.

                Args:
                    timestamps: Segments to merge and pad.
                    merge_window: Largest gap between the end of the previous
                        segment and the start of the next one that still allows a
                        merge. ``None`` disables merging (segments are only padded).
                    max_merge_size: Largest allowed span of a merged segment,
                        measured from the (already left-padded) start of the
                        current segment to the end of the candidate. A single
                        segment that is longer than this is kept as-is; it just is
                        not merged with anything. ``None`` means no size limit.
                    padding_left: Amount subtracted from a segment's start, limited
                        by the gap that is actually available. ``None`` means 0.
                    padding_right: Amount added to a segment's end, limited in the
                        same way. ``None`` means 0.

                Returns:
                    A new list of merged and padded segment dicts.
                """
                # Callers may pass None for optional settings; treat missing padding
                # as "no padding" instead of failing inside min() / arithmetic.
                if padding_left is None:
                    padding_left = 0
                if padding_right is None:
                    padding_right = 0

                result = []

                # End of the last input segment consumed so far (unpadded).
                processed_time = 0
                current_segment = None

                for next_segment in timestamps:
                    # Gap between the previously consumed segment and this one.
                    delta = next_segment['start'] - processed_time

                    if current_segment is None:
                        start_new = True
                    else:
                        too_far = merge_window is None or delta > merge_window
                        too_large = (max_merge_size is not None
                                     and next_segment['end'] - current_segment['start'] > max_merge_size)
                        start_new = too_far or too_large

                    if start_new:
                        # Finish the current segment
                        if current_segment is not None:
                            # When the gap cannot hold both paddings, the right
                            # padding takes at most half of it so the next
                            # segment's left padding gets the remainder.
                            if delta < padding_left + padding_right:
                                finish_padding = min(padding_right, delta / 2)
                            else:
                                finish_padding = padding_right

                            current_segment['end'] += finish_padding
                            delta -= finish_padding

                            result.append(current_segment)

                        # Start a new segment; copy so the caller's dict is untouched
                        current_segment = copy.deepcopy(next_segment)

                        # Left padding may only use what is left of the gap
                        current_segment['start'] = current_segment['start'] - min(padding_left, delta)
                        processed_time = current_segment['end']
                    else:
                        # Merge: extend the current segment to cover this one
                        current_segment['end'] = next_segment['end']
                        processed_time = current_segment['end']

                # Add the last segment; nothing follows it, so it gets full right padding
                if current_segment is not None:
                    current_segment['end'] += padding_right
                    result.append(current_segment)

                return result
         | 
    	
        src/vad.py
    CHANGED
    
    | @@ -5,6 +5,8 @@ from typing import Any, Deque, Iterator, List, Dict | |
| 5 |  | 
| 6 | 
             
            from pprint import pprint
         | 
| 7 |  | 
|  | |
|  | |
| 8 | 
             
            # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
         | 
| 9 | 
             
            try:
         | 
| 10 | 
             
                import tensorflow as tf
         | 
| @@ -110,8 +112,10 @@ class AbstractTranscription(ABC): | |
| 110 | 
             
                    # get speech timestamps from full audio file
         | 
| 111 | 
             
                    seconds_timestamps = self.get_transcribe_timestamps(audio)
         | 
| 112 |  | 
| 113 | 
            -
                     | 
| 114 | 
            -
                     | 
|  | |
|  | |
| 115 |  | 
| 116 | 
             
                    # A deque of transcribed segments that is passed to the next segment as a prompt
         | 
| 117 | 
             
                    prompt_window = deque()
         | 
| @@ -346,70 +350,6 @@ class AbstractTranscription(ABC): | |
| 346 | 
             
                        result.append(new_segment)
         | 
| 347 | 
             
                    return result
         | 
| 348 |  | 
| 349 | 
            -
                def pad_timestamps(self, timestamps: List[Dict[str, Any]], padding_left: float, padding_right: float):
         | 
| 350 | 
            -
                    if (padding_left == 0 and padding_right == 0):
         | 
| 351 | 
            -
                        return timestamps
         | 
| 352 | 
            -
                    
         | 
| 353 | 
            -
                    result = []
         | 
| 354 | 
            -
                    prev_entry = None
         | 
| 355 | 
            -
             | 
| 356 | 
            -
                    for i in range(len(timestamps)):
         | 
| 357 | 
            -
                        curr_entry = timestamps[i]
         | 
| 358 | 
            -
                        next_entry = timestamps[i + 1] if i < len(timestamps) - 1 else None
         | 
| 359 | 
            -
             | 
| 360 | 
            -
                        segment_start = curr_entry['start']
         | 
| 361 | 
            -
                        segment_end = curr_entry['end']
         | 
| 362 | 
            -
             | 
| 363 | 
            -
                        if padding_left is not None:
         | 
| 364 | 
            -
                            segment_start = max(prev_entry['end'] if prev_entry else 0, segment_start - padding_left)
         | 
| 365 | 
            -
                        if padding_right is not None:
         | 
| 366 | 
            -
                            segment_end = segment_end + padding_right
         | 
| 367 | 
            -
             | 
| 368 | 
            -
                            # Do not pad past the next segment
         | 
| 369 | 
            -
                            if (next_entry is not None):
         | 
| 370 | 
            -
                                segment_end = min(next_entry['start'], segment_end)
         | 
| 371 | 
            -
             | 
| 372 | 
            -
                        new_entry = { 'start': segment_start, 'end': segment_end }
         | 
| 373 | 
            -
                        prev_entry = new_entry
         | 
| 374 | 
            -
                        result.append(new_entry)
         | 
| 375 | 
            -
             | 
| 376 | 
            -
                    return result
         | 
| 377 | 
            -
             | 
| 378 | 
            -
                def merge_timestamps(self, timestamps: List[Dict[str, Any]], max_merge_gap: float, max_merge_size: float, 
         | 
| 379 | 
            -
                                            min_force_merge_gap: float, max_force_merge_size: float):
         | 
| 380 | 
            -
                    if max_merge_gap is None:
         | 
| 381 | 
            -
                        return timestamps
         | 
| 382 | 
            -
             | 
| 383 | 
            -
                    result = []
         | 
| 384 | 
            -
                    current_entry = None
         | 
| 385 | 
            -
             | 
| 386 | 
            -
                    for entry in timestamps:
         | 
| 387 | 
            -
                        if current_entry is None:
         | 
| 388 | 
            -
                            current_entry = entry
         | 
| 389 | 
            -
                            continue
         | 
| 390 | 
            -
             | 
| 391 | 
            -
                        # Get distance to the previous entry
         | 
| 392 | 
            -
                        distance = entry['start'] - current_entry['end']
         | 
| 393 | 
            -
                        current_entry_size = current_entry['end'] - current_entry['start']
         | 
| 394 | 
            -
             | 
| 395 | 
            -
                        if distance <= max_merge_gap and (max_merge_size is None or current_entry_size <= max_merge_size):
         | 
| 396 | 
            -
                            # Regular merge
         | 
| 397 | 
            -
                            current_entry['end'] = entry['end']
         | 
| 398 | 
            -
                        elif min_force_merge_gap is not None and distance <= min_force_merge_gap and \
         | 
| 399 | 
            -
                             (max_force_merge_size is None or current_entry_size <= max_force_merge_size):
         | 
| 400 | 
            -
                            # Force merge if the distance is small (up to a certain maximum size)
         | 
| 401 | 
            -
                            current_entry['end'] = entry['end']
         | 
| 402 | 
            -
                        else:
         | 
| 403 | 
            -
                            # Output current entry
         | 
| 404 | 
            -
                            result.append(current_entry)
         | 
| 405 | 
            -
                            current_entry = entry
         | 
| 406 | 
            -
                    
         | 
| 407 | 
            -
                    # Add final entry
         | 
| 408 | 
            -
                    if current_entry is not None:
         | 
| 409 | 
            -
                        result.append(current_entry)
         | 
| 410 | 
            -
             | 
| 411 | 
            -
                    return result
         | 
| 412 | 
            -
             | 
| 413 | 
             
                def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
         | 
| 414 | 
             
                    result = []
         | 
| 415 |  | 
|  | |
| 5 |  | 
| 6 | 
             
            from pprint import pprint
         | 
| 7 |  | 
| 8 | 
            +
            from src.segments import merge_timestamps
         | 
| 9 | 
            +
             | 
| 10 | 
             
            # Workaround for https://github.com/tensorflow/tensorflow/issues/48797
         | 
| 11 | 
             
            try:
         | 
| 12 | 
             
                import tensorflow as tf
         | 
|  | |
| 112 | 
             
                    # get speech timestamps from full audio file
         | 
| 113 | 
             
                    seconds_timestamps = self.get_transcribe_timestamps(audio)
         | 
| 114 |  | 
| 115 | 
            +
                    #for seconds_timestamp in seconds_timestamps:
         | 
| 116 | 
            +
                    #    print("VAD timestamp ", format_timestamp(seconds_timestamp['start']), " to ", format_timestamp(seconds_timestamp['end']))
         | 
| 117 | 
            +
             | 
| 118 | 
            +
                    merged = merge_timestamps(seconds_timestamps, self.max_silent_period, self.max_merge_size, self.segment_padding_left, self.segment_padding_right)
         | 
| 119 |  | 
| 120 | 
             
                    # A deque of transcribed segments that is passed to the next segment as a prompt
         | 
| 121 | 
             
                    prompt_window = deque()
         | 
|  | |
| 350 | 
             
                        result.append(new_segment)
         | 
| 351 | 
             
                    return result
         | 
| 352 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 353 | 
             
                def multiply_timestamps(self, timestamps: List[Dict[str, Any]], factor: float):
         | 
| 354 | 
             
                    result = []
         | 
| 355 |  | 
    	
        tests/segments_test.py
    ADDED
    
    | @@ -0,0 +1,48 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import sys
         | 
| 2 | 
            +
            import unittest
         | 
| 3 | 
            +
             | 
| 4 | 
            +
            sys.path.append('../whisper-webui')
         | 
| 5 | 
            +
             | 
| 6 | 
            +
            from src.segments import merge_timestamps
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            class TestSegments(unittest.TestCase):
                """Checks for merge_timestamps: merging, padding and tight gaps."""

                def __init__(self, *args, **kwargs):
                    super().__init__(*args, **kwargs)

                def test_merge_segments(self):
                    # Mix of gaps inside and outside the merge window, plus one
                    # segment (68-98) that is too long to be merged with anything.
                    vad_segments = [
                        {'start': 10.0, 'end': 20.0},
                        {'start': 22.0, 'end': 27.0},
                        {'start': 31.0, 'end': 35.0},
                        {'start': 45.0, 'end': 60.0},
                        {'start': 61.0, 'end': 65.0},
                        {'start': 68.0, 'end': 98.0},
                        {'start': 100.0, 'end': 102.0},
                        {'start': 110.0, 'end': 112.0},
                    ]
                    expected = [
                        {'start': 9.0, 'end': 36.0},
                        {'start': 44.0, 'end': 66.0},
                        {'start': 67.0, 'end': 99.0},
                        {'start': 99.0, 'end': 103.0},
                        {'start': 109.0, 'end': 113.0},
                    ]

                    merged = merge_timestamps(
                        vad_segments,
                        merge_window=5,
                        max_merge_size=30,
                        padding_left=1,
                        padding_right=1,
                    )

                    self.assertListEqual(merged, expected)

                def test_overlap_next(self):
                    # The gap between the two segments is smaller than the combined
                    # padding, so both paddings must meet in the middle of the gap.
                    vad_segments = [
                        {'start': 5.0, 'end': 39.182},
                        {'start': 39.986, 'end': 40.814},
                    ]
                    expected = [
                        {'start': 4.0, 'end': 39.584},
                        {'start': 39.584, 'end': 41.814},
                    ]

                    merged = merge_timestamps(
                        vad_segments,
                        merge_window=5,
                        max_merge_size=30,
                        padding_left=1,
                        padding_right=1,
                    )

                    self.assertListEqual(merged, expected)


            # Allow running this file directly as a script.
            if __name__ == '__main__':
                unittest.main()
         |