import whisperx
import torch
import numpy as np
from pyannote.audio import Pipeline
import os
import logging
import time
from difflib import SequenceMatcher
from dotenv import load_dotenv
import spaces

load_dotenv()  # pull HF_TOKEN and friends from a local .env file

hf_token = os.getenv("HF_TOKEN")

CHUNK_LENGTH = 5  # seconds per chunk; whisperx.load_audio returns 16 kHz mono audio
OVERLAP = 2  # seconds of overlap between consecutive chunks

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
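    """Split a 16 kHz waveform into overlapping fixed-size chunks.

    Sizes are in samples (seconds * 16000); the last chunk is zero-padded
    so every chunk has exactly chunk_size samples.
    """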
    chunks = []
    for i in range(0, len(audio), chunk_size - overlap):
        chunk = audio[i:i+chunk_size]
        if len(chunk) < chunk_size:
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks

@spaces.GPU
def process_audio(audio_file, translate=False, model_size="small"):
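    """Transcribe (and optionally translate) an audio file in overlapping chunks.

    Runs speaker diarization once over the whole file, detects each chunk's
    language, and returns (language_segments, merged_segments); each merged
    segment carries start/end times, language, speaker, and text.
    """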
    start_time = time.time()
    
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        compute_type = "float16" if device == "cuda" else "float32"
        audio = whisperx.load_audio(audio_file)
        model = whisperx.load_model(model_size, device, compute_type=compute_type)

        # pyannote/speaker-diarization is a gated model; hf_token must have been granted access
        diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        diarization_pipeline = diarization_pipeline.to(torch.device(device))

        # Diarize the full recording once; segments are matched against it below
        diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})

        chunks = preprocess_audio(audio)

        language_segments = []
        final_segments = []
        
        overlap_duration = OVERLAP  # seconds shared by consecutive chunks
        for i, chunk in enumerate(chunks):
            chunk_wall_start = time.time()  # wall-clock timer for this chunk
            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
            chunk_end_time = chunk_start_time + CHUNK_LENGTH
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
            lang = model.detect_language(chunk)
            result_transcribe = model.transcribe(chunk, language=lang)
            if translate:
                result_translate = model.transcribe(chunk, task="translate")
            for j, t_seg in enumerate(result_transcribe["segments"]):
                segment_start = chunk_start_time + t_seg["start"]
                segment_end = chunk_start_time + t_seg["end"]
                # Skip segments that lie entirely within the overlap with the previous chunk
                if i > 0 and segment_end <= chunk_start_time + overlap_duration:
                    logger.debug(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue

                # Skip segments that lie entirely within the overlap with the next chunk
                if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
                    logger.debug(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue
            
                # Collect every diarization speaker whose turn overlaps this segment
                speakers = []
                for turn, track, speaker in diarization_result.itertracks(yield_label=True):
                    if turn.start <= segment_end and turn.end >= segment_start:
                        speakers.append(speaker)

                segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    # Majority vote across the overlapping diarization turns
                    "speaker": max(set(speakers), key=speakers.count) if speakers else "Unknown",
                    "text": t_seg["text"],
                }
                
                # The translate pass may segment differently; guard the index
                if translate and j < len(result_translate["segments"]):
                    segment["translated"] = result_translate["segments"][j]["text"]
            
                final_segments.append(segment)

            language_segments.append({
                "language": lang,
                "start": chunk_start_time,
                "end": chunk_start_time + CHUNK_LENGTH
            })
            logger.info(f"Chunk {i+1} processed in {time.time() - chunk_wall_start:.2f} seconds")

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")

        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise

def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
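    """Merge segments that are close in time and share a long common substring,
    de-duplicating text produced by the chunk overlap.
    """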
    merged = []
    for segment in segments:
        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
            merged.append(segment)
        else:
            # Find the longest common substring between adjacent texts
            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))

            if segment['text'] and match.size / len(segment['text']) > similarity_threshold:
                # Merge: keep the earlier text and append only the non-overlapping tail
                merged[-1]['end'] = segment['end']
                merged[-1]['text'] += segment['text'][match.b + match.size:]
                # 'translated' exists only when translation was requested
                if 'translated' in merged[-1] and 'translated' in segment:
                    merged[-1]['translated'] += segment['translated'][match.b + match.size:]
            else:
                # If no significant overlap, append as a new segment
                merged.append(segment)
    return merged

def print_results(segments):
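    """Print each segment's timing, language, speaker, text, and any translation."""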
    for segment in segments:
        print(f"[{segment['start']:.2f}s - {segment['end']:.2f}s] ({segment['language']}) {segment['speaker']}:")
        print(f"Original: {segment['text']}")
        if 'translated' in segment:
            print(f"Translated: {segment['translated']}")
        print()
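
# Minimal usage sketch. "sample.wav" is a hypothetical path; this assumes HF_TOKEN is
# set (e.g. in .env) and the pyannote/speaker-diarization terms were accepted on the Hub.
if __name__ == "__main__":
    language_segments, segments = process_audio("sample.wav", translate=True, model_size="small")
    for ls in language_segments:
        print(f"Detected {ls['language']}: {ls['start']:.1f}s - {ls['end']:.1f}s")
    print_results(segments)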