import logging
import os
import time
from difflib import SequenceMatcher

import numpy as np
import spaces
import torch
import whisperx
from dotenv import load_dotenv
from pyannote.audio import Pipeline

load_dotenv()
hf_token = os.getenv("HF_TOKEN")

# Audio is chunked into 30 s windows with a 2 s overlap (at 16 kHz).
CHUNK_LENGTH = 30  # seconds
OVERLAP = 2        # seconds

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)
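# @spaces.GPU is the Hugging Face ZeroGPU decorator: on ZeroGPU Spaces it
# attaches a GPU to each decorated call, here for up to 60 seconds.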
@spaces.GPU(duration=60)
def load_whisper_model(model_size="small"):
    """Load a WhisperX model, preferring GPU with a CPU fallback."""
    logger.info(f"Loading Whisper model (size: {model_size})...")
    device = "cuda" if torch.cuda.is_available() else "cpu"
    compute_type = "float16" if device == "cuda" else "int8"
    try:
        model = whisperx.load_model(model_size, device, compute_type=compute_type)
        logger.info(f"Whisper model loaded successfully on {device}")
        return model
    except RuntimeError as e:
        logger.warning(f"Failed to load Whisper model on {device}. Falling back to CPU. Error: {str(e)}")
        model = whisperx.load_model(model_size, "cpu", compute_type="int8")
        logger.info("Whisper model loaded successfully on CPU")
        return model
@spaces.GPU(duration=60)
def load_diarization_pipeline():
    """Load the pyannote speaker-diarization pipeline; returns None on failure."""
    logger.info("Loading diarization pipeline...")
    try:
        pipeline = Pipeline.from_pretrained("pyannote/speaker-diarization", use_auth_token=hf_token)
        if torch.cuda.is_available():
            pipeline = pipeline.to(torch.device("cuda"))
        logger.info("Diarization pipeline loaded successfully")
        return pipeline
    except Exception as e:
        logger.warning(f"Diarization pipeline initialization failed: {str(e)}. Diarization will not be available.")
        return None
def preprocess_audio(audio, chunk_size=CHUNK_LENGTH * 16000, overlap=OVERLAP * 16000):
    """Split audio into fixed-size overlapping chunks; the last chunk is zero-padded.

    Note: process_audio below currently uses split_audio (no overlap) instead of this helper.
    """
    chunks = []
    for i in range(0, len(audio), chunk_size - overlap):
        chunk = audio[i:i + chunk_size]
        if len(chunk) < chunk_size:
            chunk = np.pad(chunk, (0, chunk_size - len(chunk)))
        chunks.append(chunk)
    return chunks
def merge_nearby_segments(segments, time_threshold=0.5, similarity_threshold=0.7):
    """Merge segments that are close in time, deduplicating text repeated across chunk overlaps."""
    merged = []
    for segment in segments:
        if not merged or segment['start'] - merged[-1]['end'] > time_threshold:
            merged.append(segment)
        else:
            # Find the longest run of text shared with the previous segment
            # (the transcript of the overlapping audio) and keep it only once.
            matcher = SequenceMatcher(None, merged[-1]['text'], segment['text'])
            match = matcher.find_longest_match(0, len(merged[-1]['text']), 0, len(segment['text']))
            # Guard against empty text to avoid division by zero.
            if segment['text'] and match.size / len(segment['text']) > similarity_threshold:
                merged_text = merged[-1]['text'] + segment['text'][match.b + match.size:]
                merged_translated = merged[-1].get('translated', '') + segment.get('translated', '')[match.b + match.size:]
                merged[-1]['end'] = segment['end']
                merged[-1]['text'] = merged_text
                if 'translated' in segment:
                    merged[-1]['translated'] = merged_translated
            else:
                merged.append(segment)
    return merged
def get_most_common_speaker(diarization_result, start_time, end_time):
    """Return the speaker label whose diarization turns overlap the given interval most often."""
    speakers = []
    for turn, _, speaker in diarization_result.itertracks(yield_label=True):
        if turn.start <= end_time and turn.end >= start_time:
            speakers.append(speaker)
    return max(set(speakers), key=speakers.count) if speakers else "Unknown"
def split_audio(audio, max_duration=CHUNK_LENGTH):
    """Split audio into consecutive chunks of at most max_duration seconds (16 kHz, no overlap)."""
    sample_rate = 16000
    max_samples = max_duration * sample_rate
    if len(audio) <= max_samples:
        return [audio]
    return [audio[i:i + max_samples] for i in range(0, len(audio), max_samples)]
@spaces.GPU(duration=60)
def process_audio(audio_file, translate=False, model_size="small", use_diarization=True):
    """Transcribe (and optionally translate and diarize) an audio file.

    Returns (language_segments, merged_segments): per-chunk detected languages
    and the merged transcript segments with timing, speaker, and text.
    """
    logger.info(f"Starting audio processing: translate={translate}, model_size={model_size}, use_diarization={use_diarization}")
    start_time = time.time()
    try:
        whisper_model = load_whisper_model(model_size)
        audio = whisperx.load_audio(audio_file)  # float32 mono at 16 kHz
        audio_splits = split_audio(audio)

        # Run diarization once over the whole file, if requested and available.
        diarization_result = None
        if use_diarization:
            diarization_pipeline = load_diarization_pipeline()
            if diarization_pipeline is not None:
                try:
                    diarization_result = diarization_pipeline({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": 16000})
                except Exception as e:
                    logger.warning(f"Diarization failed: {str(e)}. Proceeding without diarization.")

        language_segments = []
        final_segments = []
        for i, audio_split in enumerate(audio_splits):
            logger.info(f"Processing split {i + 1}/{len(audio_splits)}")
            result = whisper_model.transcribe(audio_split)
            lang = result["language"]
            for segment in result["segments"]:
                # Segment times are relative to the split; shift them to file time.
                segment_start = segment["start"] + (i * CHUNK_LENGTH)
                segment_end = segment["end"] + (i * CHUNK_LENGTH)
                speaker = "Unknown"
                if diarization_result is not None:
                    speaker = get_most_common_speaker(diarization_result, segment_start, segment_end)
                final_segment = {
                    "start": segment_start,
                    "end": segment_end,
                    "language": lang,
                    "speaker": speaker,
                    "text": segment["text"],
                }
                if translate:
                    # Re-run the segment's audio with task="translate". The whisperx
                    # result has no top-level "text" key, so join the segment texts.
                    translation = whisper_model.transcribe(audio_split[int(segment["start"] * 16000):int(segment["end"] * 16000)], task="translate")
                    final_segment["translated"] = " ".join(seg["text"].strip() for seg in translation["segments"])
                final_segments.append(final_segment)
            language_segments.append({
                "language": lang,
                "start": i * CHUNK_LENGTH,
                "end": min((i + 1) * CHUNK_LENGTH, len(audio) / 16000)
            })

        final_segments.sort(key=lambda x: x["start"])
        merged_segments = merge_nearby_segments(final_segments)
        end_time = time.time()
        logger.info(f"Total processing time: {end_time - start_time:.2f} seconds")
        return language_segments, merged_segments
    except Exception as e:
        logger.error(f"An error occurred during audio processing: {str(e)}")
        raise
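

# Minimal usage sketch. The file path and printing below are illustrative
# assumptions, not part of the original app wiring.
if __name__ == "__main__":
    langs, segs = process_audio("sample.wav", translate=False, model_size="small", use_diarization=False)
    for seg in segs:
        print(f"[{seg['start']:.1f}-{seg['end']:.1f}] ({seg['speaker']}, {seg['language']}) {seg['text']}")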