import logging
import os
import time
from difflib import SequenceMatcher

import numpy as np
import spaces
import torch
import whisperx
from dotenv import load_dotenv
from pyannote.audio import Pipeline

load_dotenv()

hf_token = os.getenv("HF_TOKEN")

SAMPLE_RATE = 16000  # whisperx.load_audio returns 16 kHz mono audio
CHUNK_LENGTH = 5     # seconds of audio per chunk
OVERLAP = 2          # seconds shared between consecutive chunks

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for models
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "float32"
whisper_model = None
diarization_pipeline = None

def load_models(model_size="small"):
    global whisper_model, diarization_pipeline
    
    # Load Whisper model
    whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
    
    # Try to initialize the diarization pipeline; the gated pyannote model
    # requires an HF_TOKEN whose account has accepted its access terms
    try:
        diarization_pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        diarization_pipeline = diarization_pipeline.to(torch.device(device))
    except Exception as e:
        logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
        diarization_pipeline = None

def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*SAMPLE_RATE, overlap=OVERLAP*SAMPLE_RATE):
    # Split the waveform into fixed-size overlapping chunks; the last chunk
    # is zero-padded to a full chunk_size
    chunks = []
    for i in range(0, len(audio), chunk_size - overlap):
        chunk = audio[i:i+chunk_size]
        if len(chunk) < chunk_size:
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks
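
# Example: with CHUNK_LENGTH=5 and OVERLAP=2 the stride between chunk starts
# is 3 s, so 12 s of 16 kHz audio yields chunks starting at 0 s, 3 s, 6 s and
# 9 s, the last one zero-padded to a full 5 s:
#
#   chunks = preprocess_audio(np.zeros(12 * SAMPLE_RATE))
#   assert len(chunks) == 4 and all(len(c) == 5 * SAMPLE_RATE for c in chunks)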

@spaces.GPU
def process_audio(audio_file, translate=False, model_size="small"):
    global whisper_model, diarization_pipeline
    
    if whisper_model is None:
        load_models(model_size)
    
    start_time = time.time()
    
    try:
        audio = whisperx.load_audio(audio_file)

        # Perform diarization if pipeline is available
        diarization_result = None
        if diarization_pipeline is not None:
            try:
                diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
            except Exception as e:
                logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")

        chunks = preprocess_audio(audio)

        language_segments = []
        final_segments = []
        
        for i, chunk in enumerate(chunks):
            proc_start = time.time()  # wall-clock timer for this chunk's processing
            # Chunk position on the audio timeline: starts advance by
            # CHUNK_LENGTH - OVERLAP seconds per chunk
            chunk_start_time = i * (CHUNK_LENGTH - OVERLAP)
            chunk_end_time = chunk_start_time + CHUNK_LENGTH
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
            lang = whisper_model.detect_language(chunk)
            result_transcribe = whisper_model.transcribe(chunk, language=lang)
            if translate:
                result_translate = whisper_model.transcribe(chunk, task="translate")
            for j, t_seg in enumerate(result_transcribe["segments"]):
                segment_start = chunk_start_time + t_seg["start"]
                segment_end = chunk_start_time + t_seg["end"]
                # Skip segments that fall entirely in the overlap with the previous chunk
                if i > 0 and segment_end <= chunk_start_time + OVERLAP:
                    logger.debug(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue
            
                # Skip segments that start inside the overlap with the next chunk
                if i < len(chunks) - 1 and segment_start >= chunk_end_time - OVERLAP:
                    logger.debug(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue
            
                speaker = "Unknown"
                if diarization_result is not None:
                    speakers = []
                    for turn, track, spk in diarization_result.itertracks(yield_label=True):
                        if turn.start <= segment_end and turn.end >= segment_start:
                            speakers.append(spk)
                    speaker = max(set(speakers), key=speakers.count) if speakers else "Unknown"
            
                segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": speaker,
                    "text": t_seg["text"],
                }
                
                # The translate pass may segment the chunk differently, so
                # guard the index before pairing segments up
                if translate and j < len(result_translate["segments"]):
                    segment["translated"] = result_translate["segments"][j]["text"]
            
                final_segments.append(segment)

            language_segments.append({
                "language": lang,
                "start": chunk_start_time,
                "end": chunk_end_time
            })
            logger.info(f"Chunk {i+1} processed in {time.time() - proc_start:.2f} seconds")

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")

        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise

# The merge_nearby_segments and print_results functions remain unchanged
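
# The note above refers to helpers that are not shown in this listing, yet
# merge_nearby_segments is called by process_audio. A minimal sketch of both
# follows so the file runs standalone. This is an assumption, not the original
# code: it presumes merge_nearby_segments joins consecutive same-speaker,
# same-language segments and uses SequenceMatcher (imported above) to drop
# text duplicated across the chunk overlap; max_gap is a hypothetical knob.
def merge_nearby_segments(segments, max_gap=0.5):
    merged = []
    for seg in segments:
        prev = merged[-1] if merged else None
        if (prev is not None
                and seg["speaker"] == prev["speaker"]
                and seg["language"] == prev["language"]
                and seg["start"] - prev["end"] <= max_gap):
            # Find the longest block of text shared by the two segments and
            # keep only the non-duplicated tail of the incoming one
            matcher = SequenceMatcher(None, prev["text"], seg["text"])
            match = matcher.find_longest_match(0, len(prev["text"]), 0, len(seg["text"]))
            tail = seg["text"][match.b + match.size:] if match.size > 10 else seg["text"]
            prev["text"] += tail
            if "translated" in prev and "translated" in seg:
                prev["translated"] += " " + seg["translated"]
            prev["end"] = seg["end"]
        else:
            merged.append(dict(seg))
    return merged

def print_results(language_segments, segments):
    for ls in language_segments:
        print(f"[{ls['start']:.2f}-{ls['end']:.2f}] language: {ls['language']}")
    for seg in segments:
        line = f"[{seg['start']:.2f}-{seg['end']:.2f}] {seg['speaker']}: {seg['text']}"
        if "translated" in seg:
            line += f" (EN: {seg['translated']})"
        print(line)

# Hypothetical local usage (the audio path is a placeholder):
# if __name__ == "__main__":
#     language_segments, segments = process_audio("sample.wav", translate=True)
#     print_results(language_segments, segments)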