import whisperx
import torch
import numpy as np
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv
import logging
import time
from difflib import SequenceMatcher
import spaces

load_dotenv()

hf_token = os.getenv("HF_TOKEN")

CHUNK_LENGTH = 30  # seconds per chunk (default used by preprocess_audio)
OVERLAP = 2  # seconds of overlap between consecutive chunks

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)


@spaces.GPU(duration=60)
def load_whisper_model(model_size="small"):
    logger.info(f"Loading Whisper model (size: {model_size})...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "float16" if device == "cuda" else "int8"
    try:
        model = whisperx.load_model(model_size, device, compute_type=compute_type)
        logger.info(f"Whisper model loaded successfully on {device}")
        return model
    except RuntimeError as e:
        logger.warning(f"Failed to load Whisper model on {device}. Falling back to CPU. Error: {str(e)}")
        device = "cpu"
        compute_type = "int8"
        model = whisperx.load_model(model_size, device, compute_type=compute_type)
        logger.info("Whisper model loaded successfully on CPU")
        return model


@spaces.GPU(duration=60)
def load_diarization_pipeline():
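    """Load the pyannote speaker-diarization pipeline, or return None if it cannot be initialized."""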
    logger.info("Loading diarization pipeline...")
    try:
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        if torch.cuda.is_available():
            pipeline = pipeline.to(torch.device("cuda"))
        logger.info("Diarization pipeline loaded successfully")
        return pipeline
    except Exception as e:
        logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
        return None


def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
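    """Split audio into overlapping chunks of chunk_size samples, zero-padding the final chunk."""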
    chunks = []
    for i in range(0, len(audio), chunk_size - overlap):
        chunk = audio[i:i+chunk_size]
        if len(chunk) < chunk_size:
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks


def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
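    """Merge segments that are close in time and whose texts overlap substantially.

    Overlap is detected with difflib.SequenceMatcher; the duplicated prefix of the later
    segment is dropped before its text (and translation, if present) is appended.
    """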
    merged = []
    for segment in segments:
        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
            merged.append(segment)
        else:
            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
            
            # Guard against empty segment text to avoid division by zero
            if segment['text'] and match.size / len(segment['text']) > similarity_threshold:
                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
                merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
                
                merged[-1]['end'] = segment['end']
                merged[-1]['text'] = merged_text
                if 'translated' in segment:
                    merged[-1]['translated'] = merged_translated
            else:
                merged.append(segment)
    return merged

def get_most_common_speaker(diarization_result, start_time, end_time):
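    """Return the speaker label that most frequently overlaps the [start_time, end_time] window."""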
    speakers = []
    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        if turn.start <= end_time and turn.end >= start_time:
            speakers.append(speaker)
    return max(set(speakers), key=speakers.count) if speakers else "Unknown"

def split_audio(audio, max_duration=30):
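    """Split a 16 kHz waveform into consecutive chunks of at most max_duration seconds."""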
    sample_rate = 16000
    max_samples = max_duration * sample_rate
    
    if len(audio) <= max_samples:
        return [audio]
    
    splits = []
    for i in range(0, len(audio), max_samples):
        splits.append(audio[i:i+max_samples])
    
    return splits

@spaces.GPU(duration=60)
def process_audio(audio_file, translate=False, model_size="small", use_diarization=True):
    logger.info(f"Starting audio processing: translate={translate}, model_size={model_size}, use_diarization={use_diarization}")
    start_time = time.time()
    
    try:
        whisper_model = load_whisper_model(model_size)
        audio = whisperx.load_audio(audio_file)
        audio_splits = split_audio(audio)

        diarization_result = None
        if use_diarization:
            diarization_pipeline = load_diarization_pipeline()
            if diarization_pipeline is not None:
                try:
                    diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
                except Exception as e:
                    logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")

        language_segments = []
        final_segments = []
        
        for i, audio_split in enumerate(audio_splits):
            logger.info(f"Processing split {i+1}/{len(audio_splits)}")
            
            result = whisper_model.transcribe(audio_split)
            lang = result["language"]
            
            for segment in result["segments"]:
                segment_start = segment["start"] + (i * 30)
                segment_end = segment["end"] + (i * 30)
                
                speaker = "Unknown"
                if diarization_result is not None:
                    speaker = get_most_common_speaker(diarization_result, segment_start, segment_end)
                
                final_segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": speaker,
                    "text": segment["text"],
                }
                
                if translate:
                    # whisperx's transcribe returns {"segments": [...], "language": ...} with no
                    # top-level "text" key, so join the translated segment texts instead of
                    # indexing translation["text"].
                    translation = whisper_model.transcribe(
                        audio_split[int(segment["start"] * 16000):int(segment["end"] * 16000)],
                        task="translate",
                    )
                    final_segment["translated"] = " ".join(
                        seg["text"].strip() for seg in translation.get("segments", [])
                    )
                
                final_segments.append(final_segment)
            
            language_segments.append({
                "language": lang,
                "start": i * 30,
                "end": min((i + 1) * 30, len(audio) / 16000)
            })

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")

        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise
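

# Minimal usage sketch (not part of the Space's app code): "sample.wav" is a hypothetical
# local file path; diarization is disabled here so no HF token is required.
if __name__ == "__main__":
    language_segments, segments = process_audio(
        "sample.wav", translate=False, model_size="small", use_diarization=False
    )
    print("Detected language windows:", language_segments)
    for seg in segments:
        print(f"[{seg['start']:.1f}-{seg['end']:.1f}] ({seg['language']}, {seg['speaker']}) {seg['text']}")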