ASR_gradio / audio_processing.py
Kr08's picture
Update audio_processing.py
6e73abb verified
raw
history blame
5.53 kB
import whisperx
import torch
import numpy as np
from scipy.signal import resample
from pyannote.audio import Pipeline
import os
from dotenv import load_dotenv
load_dotenv()
import logging
import time
from difflib import SequenceMatcher
import spaces
# Hugging Face access token taken from the environment (None when unset).
hf_token = os.environ.get("HF_TOKEN")

# Chunk length and overlap between consecutive chunks; both are multiplied
# by 16000 to obtain the default sample counts used by preprocess_audio.
CHUNK_LENGTH, OVERLAP = 5, 2

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Global variables for models: pick the device and a matching precision.
if torch.cuda.is_available():
    device, compute_type = "cuda", "float16"
else:
    device, compute_type = "cpu", "float32"

# Both models are filled in by load_models(); None means "not loaded".
whisper_model = diarization_pipeline = None
def load_models(model_size="small"):
    """Populate the module-level Whisper model and diarization pipeline.

    The Whisper model is loaded with ``whisperx.load_model`` using the
    module-level ``device`` and ``compute_type``.  The pyannote diarization
    pipeline is best-effort: any exception raised while creating it or
    moving it to the device is logged as a warning and the global
    ``diarization_pipeline`` is left as ``None``.

    Args:
        model_size: Model name forwarded to ``whisperx.load_model``.
    """
    global whisper_model, diarization_pipeline

    # Load Whisper model
    whisper_model = whisperx.load_model(
        model_size, device, compute_type=compute_type
    )

    # Try to initialize diarization pipeline
    try:
        pipeline = Pipeline.from_pretrained(
            "pyannote/speaker-diarization", use_auth_token=hf_token
        )
        diarization_pipeline = pipeline.to(torch.device(device))
    except Exception as e:
        logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
        diarization_pipeline = None
def preprocess_audio(audio, chunk_size=CHUNK_LENGTH*16000, overlap=OVERLAP*16000):
    """Split an audio array into fixed-length, overlapping chunks.

    Consecutive chunks start ``chunk_size - overlap`` samples apart.  A chunk
    that runs past the end of ``audio`` is padded at its end (``np.pad``
    default mode) so that every returned chunk has length ``chunk_size``.

    Args:
        audio: Sequence of samples supporting ``len`` and slicing.
            NOTE(review): presumably a 1-D numpy array -- the pad spec
            ``(0, n)`` would be applied to every axis otherwise; confirm.
        chunk_size: Chunk length in samples.
        overlap: Number of samples shared by two consecutive chunks.

    Returns:
        list: The chunks in order; empty when ``audio`` is empty.

    Raises:
        ValueError: If ``overlap`` is not smaller than ``chunk_size``.
    """
    stride = chunk_size - overlap
    if stride <= 0:
        # A zero stride would make range() raise a cryptic error and a
        # negative one would silently yield no chunks at all.
        raise ValueError(
            f"overlap ({overlap}) must be smaller than chunk_size ({chunk_size})"
        )

    chunks = []
    for start in range(0, len(audio), stride):
        chunk = audio[start:start + chunk_size]
        shortfall = chunk_size - len(chunk)
        if shortfall > 0:
            chunk = np.pad(chunk, (0, shortfall))
        chunks.append(chunk)
    return chunks
def _dominant_speaker(turns, segment_start, segment_end):
    """Return the label with the most diarization turns touching a time span.

    ``turns`` is a list of ``(turn, track, label)`` tuples as produced by
    ``itertracks(yield_label=True)``.  Returns ``"Unknown"`` when no turn
    touches ``[segment_start, segment_end]``.
    """
    speakers = [
        label
        for turn, _track, label in turns
        if turn.start <= segment_end and turn.end >= segment_start
    ]
    if not speakers:
        return "Unknown"
    return max(set(speakers), key=speakers.count)


@spaces.GPU
def process_audio(audio_file, translate=False, model_size="small"):
    """Transcribe (and optionally translate) an audio file chunk by chunk.

    The audio is loaded with ``whisperx.load_audio``, optionally diarized,
    split with ``preprocess_audio`` and each chunk is language-detected and
    transcribed.  Segment times are shifted from chunk-local to file-level
    offsets, tagged with a speaker, sorted and passed through
    ``merge_nearby_segments``.

    Args:
        audio_file: Path (or whatever ``whisperx.load_audio`` accepts).
        translate: When true, each chunk is also run with
            ``task="translate"`` and segments get a ``"translated"`` key.
        model_size: Forwarded to ``load_models``; only used when no Whisper
            model has been loaded yet (an already-loaded model is reused
            regardless of this value).

    Returns:
        tuple: ``(language_segments, merged_segments)`` where
        ``language_segments`` has one ``{"language", "start", "end"}`` dict
        per chunk and ``merged_segments`` is the result of
        ``merge_nearby_segments`` on the per-segment dicts
        (``start``, ``end``, ``language``, ``speaker``, ``text`` and,
        if requested, ``translated``).

    Raises:
        Exception: Any error raised during processing is logged and
            re-raised unchanged.
    """
    global whisper_model, diarization_pipeline

    if whisper_model is None:
        load_models(model_size)

    start_time = time.time()

    try:
        audio = whisperx.load_audio(audio_file)

        # Perform diarization if pipeline is available
        diarization_result = None
        if diarization_pipeline is not None:
            try:
                diarization_result = diarization_pipeline(
                    {"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000}
                )
            except Exception as e:
                logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")

        # Materialise the diarization tracks once instead of re-iterating
        # them for every transcribed segment.
        if diarization_result is not None:
            turns = list(diarization_result.itertracks(yield_label=True))
        else:
            turns = []

        chunks = preprocess_audio(audio)

        language_segments = []
        final_segments = []

        # Must match the defaults preprocess_audio chunks with, otherwise the
        # file-level timestamps below drift away from the real chunk offsets.
        overlap_duration = OVERLAP
        stride = CHUNK_LENGTH - overlap_duration
        last_index = len(chunks) - 1

        for i, chunk in enumerate(chunks):
            # Wall-clock timer, kept separate from the audio-timeline offsets.
            chunk_timer_start = time.time()

            chunk_start_time = i * stride
            chunk_end_time = chunk_start_time + CHUNK_LENGTH

            logger.info(f"Processing chunk {i+1}/{len(chunks)}")
            lang = whisper_model.detect_language(chunk)
            result_transcribe = whisper_model.transcribe(chunk, language=lang)

            translated_segments = []
            if translate:
                # Reuse the detected language so it is not detected twice.
                result_translate = whisper_model.transcribe(
                    chunk, language=lang, task="translate"
                )
                translated_segments = result_translate["segments"]

            # NOTE(review): a segment lying wholly inside an overlap region
            # satisfies the "next chunk" rule in the earlier chunk and the
            # "previous chunk" rule in the later one, so it is skipped by
            # both. Confirm whether merge_nearby_segments relies on this
            # before changing the rules.
            for j, t_seg in enumerate(result_transcribe["segments"]):
                segment_start = chunk_start_time + t_seg["start"]
                segment_end = chunk_start_time + t_seg["end"]

                # Skip segments in the overlapping region of the previous chunk
                if i > 0 and segment_end <= chunk_start_time + overlap_duration:
                    logger.info(
                        "Skipping segment in overlap with previous chunk: %.2f - %.2f",
                        segment_start, segment_end,
                    )
                    continue

                # Skip segments in the overlapping region of the next chunk
                if i < last_index and segment_start >= chunk_end_time - overlap_duration:
                    logger.info(
                        "Skipping segment in overlap with next chunk: %.2f - %.2f",
                        segment_start, segment_end,
                    )
                    continue

                segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": _dominant_speaker(turns, segment_start, segment_end),
                    "text": t_seg["text"],
                }

                if translate:
                    # The translation pass may return fewer segments than the
                    # transcription pass; fall back to an empty string rather
                    # than failing the whole file with an IndexError.
                    if j < len(translated_segments):
                        segment["translated"] = translated_segments[j]["text"]
                    else:
                        segment["translated"] = ""

                final_segments.append(segment)

            language_segments.append({
                "language": lang,
                "start": chunk_start_time,
                "end": chunk_end_time,
            })

            elapsed = time.time() - chunk_timer_start
            logger.info(f"Chunk {i+1} processed in {elapsed:.2f} seconds")

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)

        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")

        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise
# The merge_nearby_segments and print_results functions remain unchanged