File size: 6,863 Bytes
745e5b6
f36e52e
745e5b6
 
 
 
 
 
62b6f11
 
 
 
745e5b6
 
43f1b5e
d707938
62b6f11
43f1b5e
62b6f11
 
745e5b6
6e73abb
 
51a5dfa
6e73abb
 
 
 
51a5dfa
6e73abb
 
51a5dfa
 
 
 
 
 
 
fd4f883
 
 
6e73abb
 
 
 
51a5dfa
 
6e73abb
 
 
 
62b6f11
745e5b6
62b6f11
745e5b6
 
 
 
 
 
fd4f883
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b34b775
 
6e73abb
 
 
 
 
62b6f11
 
 
 
fd4f883
745e5b6
fd4f883
6e73abb
fd4f883
 
 
 
 
 
 
 
745e5b6
62b6f11
 
745e5b6
fd4f883
 
745e5b6
fd4f883
 
62b6f11
fd4f883
 
 
 
5cf5423
 
fd4f883
 
 
62b6f11
 
 
5cf5423
fd4f883
62b6f11
 
 
fd4f883
 
 
 
62b6f11
 
 
fd4f883
 
62b6f11
745e5b6
62b6f11
 
745e5b6
62b6f11
 
 
 
 
 
0b6f315
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
import whisperx
import torch
import numpy as np
from scipy.signal import resample
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv
load_dotenv()
import logging
import time
from difflib import SequenceMatcher
import spaces

# Hugging Face access token taken from the environment (load_dotenv() above also
# reads a .env file); it is passed to pyannote when loading the diarization pipeline.
hf_token = os.getenv("HF_TOKEN")

# Window length and overlap used for chunking, both in seconds.
CHUNK_LENGTH = 30
OVERLAP = 2

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Shared model state. Both models start unloaded (None) and are created on
# demand; device/compute_type may later be downgraded to CPU if a GPU load fails.
device = "cpu" if not torch.cuda.is_available() else "cuda"
compute_type = "int8" if device != "cuda" else "float16"
whisper_model = diarization_pipeline = None

def load_models(model_size="small"):
    """Load the WhisperX model into the module-level ``whisper_model``.

    The load is first attempted with the module-level ``device`` and
    ``compute_type``. If that raises ``RuntimeError`` while not already using
    the CPU/int8 configuration, a warning is logged, ``device`` and
    ``compute_type`` are switched to "cpu"/"int8" at module level (so any
    later reader of those globals sees the CPU settings), and the load is
    retried once.

    Args:
        model_size: Model name forwarded to ``whisperx.load_model``.

    Raises:
        RuntimeError: re-raised unchanged when the first attempt already used
            the CPU/int8 configuration. Whatever ``whisperx.load_model`` raises
            on the CPU retry is propagated as well.
    """
    global whisper_model, device, compute_type

    try:
        whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
    except RuntimeError as e:
        if device == "cpu" and compute_type == "int8":
            # Already using the fallback configuration: retrying with the same
            # settings cannot succeed and the fallback warning would be misleading.
            raise
        logger.warning(f"Failed to load Whisper model on {device}. Falling back to CPU. Error: {str(e)}")
        device = "cpu"
        compute_type = "int8"
        whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)

def load_diarization_pipeline():
    """Populate the module-level ``diarization_pipeline`` on a best-effort basis.

    Loads "pyannote/speaker-diarization" using ``hf_token`` and, when the
    module-level ``device`` is "cuda", moves the pipeline onto that device.
    Any exception is logged as a warning and leaves ``diarization_pipeline``
    set to ``None``, so callers can continue without speaker labels.
    """
    global diarization_pipeline, device

    try:
        candidate = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        if device == "cuda":
            candidate = candidate.to(torch.device(device))
    except Exception as e:
        logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
        diarization_pipeline = None
    else:
        diarization_pipeline = candidate

def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
    """Cut ``audio`` into fixed-length, overlapping windows.

    Consecutive windows start ``chunk_size - overlap`` samples apart, so each
    window shares ``overlap`` samples with the one before it. A window that
    runs past the end of ``audio`` is zero-padded with ``np.pad`` up to
    ``chunk_size``. Windowing stops as soon as a window reaches the end of
    ``audio``; any further window would contain only samples that were
    already covered, plus padding.

    Args:
        audio: Sequence of samples. Presumably a 1-D numpy array at 16 kHz,
            given the ``* 16000`` defaults; confirm with callers.
        chunk_size: Window length in samples.
        overlap: Number of samples shared by consecutive windows; must satisfy
            ``0 <= overlap < chunk_size``.

    Returns:
        List of windows, each of length ``chunk_size``; empty when ``audio``
        is empty.

    Raises:
        ValueError: if ``chunk_size`` is not positive or ``overlap`` lies
            outside ``[0, chunk_size)``.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")
    if not 0 <= overlap < chunk_size:
        raise ValueError(
            f"overlap must be in [0, chunk_size), got overlap={overlap}, chunk_size={chunk_size}"
        )

    step = chunk_size - overlap
    total = len(audio)
    chunks = []
    for start in range(0, total, step):
        chunk = audio[start:start + chunk_size]
        if len(chunk) < chunk_size:
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
        if start + chunk_size >= total:
            # This window already reaches the end of the signal.
            break
    return chunks

def _append_unseen_text(previous, incoming, similarity_threshold):
    """Return ``previous`` followed by the part of ``incoming`` not already in it.

    Finds the longest common run of characters between the two strings. If
    that run covers more than ``similarity_threshold`` of ``incoming``, the
    prefix of ``incoming`` up to the end of the run is treated as repeated
    text and dropped. Otherwise ``incoming`` is appended unchanged.
    """
    if not incoming:
        return previous
    match = SequenceMatcher(None, previous, incoming).find_longest_match(
        0, len(previous), 0, len(incoming)
    )
    if match.size / len(incoming) > similarity_threshold:
        return previous + incoming[match.b + match.size:]
    return previous + incoming


def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
    """Collapse consecutive segments whose text repeats across a boundary.

    Segments are processed in the given order. A segment is a merge candidate
    when its ``start`` is at most ``time_threshold`` after the current last
    merged segment's ``end``. It is merged when the longest common substring
    between the two texts covers more than ``similarity_threshold`` of the new
    segment's text. In that case, only the new text after that common run is
    appended, and ``end`` is extended. A ``translated`` entry, if present, is
    deduplicated against the previous translation in the same way, using its
    own overlap. All other segments are kept as separate entries.

    Args:
        segments: Dicts with at least ``start``, ``end`` and ``text`` keys,
            and optionally ``translated``.
        time_threshold: Maximum gap (same units as ``start``/``end``) for two
            segments to be considered for merging.
        similarity_threshold: Fraction of the new segment's text that must be
            covered by the common substring for a merge to happen.

    Returns:
        A new list of new dicts; the input dicts are not modified.
    """
    merged = []
    for segment in segments:
        if merged and segment['start'] - merged[-1]['end'] <= time_threshold:
            last = merged[-1]
            text = segment['text']
            # Empty text has nothing to deduplicate (and would divide by zero).
            if text:
                match = SequenceMatcher(None, last['text'], text).find_longest_match(
                    0, len(last['text']), 0, len(text)
                )
                if match.size / len(text) > similarity_threshold:
                    last['text'] += text[match.b + match.size:]
                    # max(): a segment nested inside the previous one must not shrink it.
                    last['end'] = max(last['end'], segment['end'])
                    if 'translated' in segment:
                        # The translation has its own wording, so its overlap is
                        # located independently instead of reusing offsets from `text`.
                        last['translated'] = _append_unseen_text(
                            last.get('translated', ''), segment['translated'], similarity_threshold
                        )
                    continue
        # Copy so that later merges never mutate the caller's dictionaries.
        merged.append(dict(segment))
    return merged

# Helper function to get the most common speaker in a time range
def get_most_common_speaker(diarization_result, start_time, end_time):
    """Return the speaker label with the most diarization turns touching a time range.

    A turn counts when it intersects ``[start_time, end_time]``. The bounds
    are inclusive, so a turn that merely touches an endpoint is counted. Ties
    are broken deterministically in favour of the label first yielded by
    ``itertracks``.

    Args:
        diarization_result: Presumably a pyannote ``Annotation``. Anything whose
            ``itertracks(yield_label=True)`` yields ``(turn, track, label)``
            with ``turn.start``/``turn.end`` works.
        start_time: Range start, in the same units as the turns.
        end_time: Range end, in the same units as the turns.

    Returns:
        The winning label, or "Unknown" if no turn intersects the range.
    """
    from collections import Counter

    counts = Counter(
        speaker
        for turn, _, speaker in diarization_result.itertracks(yield_label=True)
        if turn.start <= end_time and turn.end >= start_time
    )
    if not counts:
        return "Unknown"
    # most_common orders equal counts by first occurrence, unlike iterating a set.
    return counts.most_common(1)[0][0]

# Helper function to split long audio files
def split_audio(audio, max_duration=30, sample_rate=16000):
    """Split ``audio`` into consecutive, non-overlapping pieces.

    Each piece holds at most ``round(max_duration * sample_rate)`` samples;
    the last piece may be shorter. Audio that already fits (including empty
    audio) is returned unchanged as a single-element list.

    Args:
        audio: Sliceable sequence of samples.
        max_duration: Maximum piece length in seconds; may be fractional.
        sample_rate: Samples per second of ``audio``.

    Returns:
        List of slices of ``audio``.

    Raises:
        ValueError: if the piece length works out to fewer than one sample.
    """
    # round() so that fractional durations give an integer sample count
    # (range() rejects float steps) without float truncation surprises.
    max_samples = round(max_duration * sample_rate)
    if max_samples <= 0:
        raise ValueError(
            f"max_duration * sample_rate must be at least one sample, got {max_samples}"
        )

    if len(audio) <= max_samples:
        return [audio]

    return [audio[i:i + max_samples] for i in range(0, len(audio), max_samples)]

# Main processing function with optimizations
def _result_text(result):
    """Return the transcript text of a ``whisper_model.transcribe`` result dict.

    Uses a top-level "text" entry when one is present. Otherwise the
    "text" of each entry in "segments" is concatenated; "segments" is the
    shape the main loop of ``process_audio`` already relies on.
    """
    text = result.get("text")
    if text is not None:
        return text
    return "".join(seg["text"] for seg in result.get("segments", []))


@spaces.GPU(duration=60)
def process_audio(audio_file, translate=False, model_size="small", use_diarization=True):
    """Transcribe an audio file, optionally with speaker labels and translation.

    The audio is cut into fixed 30-second splits. Each split is transcribed
    separately, and segment timestamps are shifted by the split's offset.
    When ``use_diarization`` is set and a diarization pipeline can be loaded,
    each segment is labelled with the speaker that has the most turns in it.
    Otherwise the label is "Unknown". When ``translate`` is set, each segment's
    audio is additionally run through the model with ``task="translate"``.

    Args:
        audio_file: Path or input accepted by ``whisperx.load_audio``.
        translate: Add a "translated" entry to every segment.
        model_size: Model to load, used only when no Whisper model is loaded
            yet; an already loaded model is reused regardless of this value.
        use_diarization: Attempt speaker diarization.

    Returns:
        ``(language_segments, merged_segments)``. ``language_segments`` holds
        one ``{"language", "start", "end"}`` dict per split (times in seconds).
        ``merged_segments`` holds segment dicts with "start", "end", "language",
        "speaker", "text" and optionally "translated", sorted by start and
        passed through ``merge_nearby_segments``.

    Raises:
        Any exception from loading or transcription is logged with its
        traceback and re-raised. Diarization failures are only logged.
    """
    global whisper_model, diarization_pipeline

    if whisper_model is None:
        load_models(model_size)

    sample_rate = 16000  # sample rate assumed by every seconds<->samples conversion below
    split_seconds = 30   # split length; also the per-split timestamp offset
    start_time = time.time()

    try:
        audio = whisperx.load_audio(audio_file)
        audio_splits = split_audio(audio, max_duration=split_seconds)

        # Perform diarization if requested and pipeline is available
        diarization_result = None
        if use_diarization:
            if diarization_pipeline is None:
                load_diarization_pipeline()
            if diarization_pipeline is not None:
                try:
                    diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": sample_rate})
                except Exception as e:
                    logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")

        language_segments = []
        final_segments = []

        for i, audio_split in enumerate(audio_splits):
            logger.info(f"Processing split {i+1}/{len(audio_splits)}")
            offset = i * split_seconds

            result = whisper_model.transcribe(audio_split)
            lang = result["language"]

            for segment in result["segments"]:
                # Segment times are relative to the split; shift to file time.
                segment_start = segment["start"] + offset
                segment_end = segment["end"] + offset

                speaker = "Unknown"
                if diarization_result is not None:
                    speaker = get_most_common_speaker(diarization_result, segment_start, segment_end)

                final_segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": speaker,
                    "text": segment["text"],
                }

                if translate:
                    # Slice the current split with split-relative times.
                    clip = audio_split[int(segment["start"] * sample_rate):int(segment["end"] * sample_rate)]
                    # Reuse the split's detected language instead of re-detecting
                    # it on a short clip.
                    translation = whisper_model.transcribe(clip, task="translate", language=lang)
                    final_segment["translated"] = _result_text(translation)

                final_segments.append(final_segment)

            language_segments.append({
                "language": lang,
                "start": offset,
                "end": min(offset + split_seconds, len(audio) / sample_rate),
            })

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")

        return language_segments, merged_segments
    except Exception as e:
        # logger.exception keeps the traceback that logger.error would drop.
        logger.exception(f"An error occurred during audio processing: {str(e)}")
        raise