import spaces  # ZeroGPU helper; imported early so it can manage CUDA before torch touches it
import os
import time
import logging
from difflib import SequenceMatcher

import numpy as np
import torch
import whisperx
from scipy.signal import resample
from pyannote.audio import Pipeline
from dotenv import load_dotenv

load_dotenv()
hf_token = os.getenv("HF_TOKEN")
CHUNK_LENGTH = 5  # chunk length in seconds
OVERLAP = 2       # overlap between consecutive chunks, in seconds
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# Global variables for models
device = "cuda" if torch.cuda.is_available() else "cpu"
compute_type = "float16" if device == "cuda" else "float32"
whisper_model = None
diarization_pipeline = None
def load_models(model_size="small"):
    """Load the WhisperX model and, if possible, the pyannote diarization pipeline."""
    global whisper_model, diarization_pipeline
    # Load Whisper model
    whisper_model = whisperx.load_model(model_size, device, compute_type=compute_type)
    # Try to initialize the diarization pipeline (requires a valid HF_TOKEN with
    # access to the gated pyannote/speaker-diarization model)
    try:
        diarization_pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization", use_auth_token=hf_token
        )
        diarization_pipeline = diarization_pipeline.to(torch.device(device))
    except Exception as e:
        logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
        diarization_pipeline = None
def preprocess_audio(audio, chunk_size=CHUNK_LENGTH * 16000, overlap=OVERLAP * 16000):
    """Split a 16 kHz waveform into fixed-size chunks that overlap by `overlap` samples."""
    chunks = []
    step = chunk_size - overlap
    for i in range(0, len(audio), step):
        chunk = audio[i:i + chunk_size]
        if len(chunk) < chunk_size:
            # Zero-pad the final chunk so every chunk has the same length
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks
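
# Worked example of the chunking math above: with CHUNK_LENGTH = 5 and OVERLAP = 2,
# the stride is 3 s, so 12 s of audio yields chunks covering 0-5 s, 3-8 s, 6-11 s,
# and 9-14 s (the last chunk zero-padded past the 12 s mark).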
@spaces.GPU
def process_audio(audio_file, translate=False, model_size="small"):
    global whisper_model, diarization_pipeline
    if whisper_model is None:
        load_models(model_size)
    start_time = time.time()
    try:
        audio = whisperx.load_audio(audio_file)
        # Perform diarization if the pipeline is available
        diarization_result = None
        if diarization_pipeline is not None:
            try:
                diarization_result = diarization_pipeline(
                    {"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000}
                )
            except Exception as e:
                logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")
        chunks = preprocess_audio(audio)
        language_segments = []
        final_segments = []
        overlap_duration = OVERLAP  # seconds of overlap between consecutive chunks
        for i, chunk in enumerate(chunks):
            chunk_proc_start = time.time()  # wall-clock timer for this chunk
            # Position of this chunk on the audio timeline, in seconds
            chunk_start_time = i * (CHUNK_LENGTH - overlap_duration)
            chunk_end_time = chunk_start_time + CHUNK_LENGTH
            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
            lang = whisper_model.detect_language(chunk)
            result_transcribe = whisper_model.transcribe(chunk, language=lang)
            if translate:
                result_translate = whisper_model.transcribe(chunk, task="translate")
            for j, t_seg in enumerate(result_transcribe["segments"]):
                segment_start = chunk_start_time + t_seg["start"]
                segment_end = chunk_start_time + t_seg["end"]
                # Skip segments that fall entirely in the overlap with the previous chunk
                if i > 0 and segment_end <= chunk_start_time + overlap_duration:
                    logger.info(f"Skipping segment in overlap with previous chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue
                # Skip segments that fall entirely in the overlap with the next chunk
                if i < len(chunks) - 1 and segment_start >= chunk_end_time - overlap_duration:
                    logger.info(f"Skipping segment in overlap with next chunk: {segment_start:.2f} - {segment_end:.2f}")
                    continue
                # Assign the speaker by majority vote over the diarization turns
                # that overlap this segment in time
                speaker = "Unknown"
                if diarization_result is not None:
                    speakers = []
                    for turn, track, spk in diarization_result.itertracks(yield_label=True):
                        if turn.start <= segment_end and turn.end >= segment_start:
                            speakers.append(spk)
                    speaker = max(set(speakers), key=speakers.count) if speakers else "Unknown"
                segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": speaker,
                    "text": t_seg["text"],
                }
                # Guard the index: translated segments are not guaranteed to align
                # one-to-one with transcribed segments
                if translate and j < len(result_translate["segments"]):
                    segment["translated"] = result_translate["segments"][j]["text"]
                final_segments.append(segment)
            language_segments.append({
                "language": lang,
                "start": chunk_start_time,
                "end": chunk_start_time + CHUNK_LENGTH
            })
            logger.info(f"Chunk {i+1} processed in {time.time() - chunk_proc_start:.2f} seconds")
        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)
        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise
# The merge_nearby_segments and print_results functions remain unchanged
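
# A minimal local usage sketch (not part of the Space's Gradio wiring):
# "sample.wav" is a hypothetical placeholder path, and merge_nearby_segments
# must be defined, as it is elsewhere in this file.
if __name__ == "__main__":
    language_segments, merged_segments = process_audio("sample.wav", translate=False)
    for seg in merged_segments:
        print(f"[{seg['start']:.2f}-{seg['end']:.2f}] ({seg['language']}, {seg['speaker']}): {seg['text']}")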