|
import gradio as gr
import numpy as np
import torch
import torchaudio
import queue
import time
from scipy.spatial.distance import cosine
import librosa
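
# Dependency note: gradio, numpy, torch, torchaudio, scipy and librosa are
# required; openai-whisper is optional — without it the app falls back to a
# mock transcription (see RealTimeASRDiarization.__init__).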
|
|
|
|
|
# Configuration
FINAL_TRANSCRIPTION_MODEL = "openai/whisper-small"  # HF-style id, kept for reference only
TRANSCRIPTION_LANGUAGE = "en"
DEFAULT_CHANGE_THRESHOLD = 0.7  # similarity below this may trigger a speaker change
EMBEDDING_HISTORY_SIZE = 5      # recent embeddings kept for comparison
MIN_SEGMENT_DURATION = 1.0      # seconds between possible speaker changes
DEFAULT_MAX_SPEAKERS = 4
ABSOLUTE_MAX_SPEAKERS = 6
SAMPLE_RATE = 16000
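
# How the threshold is applied (sketch): for a new embedding e and the current
# speaker's running mean m, similarity = 1 - cosine(e, m), i.e. the cosine
# similarity (e . m) / (||e|| * ||m||). A speaker change is only considered
# when similarity < change_threshold AND at least MIN_SEGMENT_DURATION seconds
# have passed since the last detected change.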
|
|
|
|
|
SPEAKER_COLORS = [
    "#FFD700",  # Gold
    "#FF6B6B",  # Red
    "#4ECDC4",  # Teal
    "#45B7D1",  # Blue
    "#96CEB4",  # Green
    "#FFEAA7",  # Yellow
]

SPEAKER_COLOR_NAMES = [
    "Gold", "Red", "Teal", "Blue", "Green", "Yellow"
]
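
# One entry per speaker slot: both lists must stay at least
# ABSOLUTE_MAX_SPEAKERS long, since the legend below indexes them directly.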
|
|
|
|
|
class SpeechBrainEncoder:
    """Lightweight stand-in for a SpeechBrain speaker encoder.

    Uses time-averaged MFCC features from torchaudio as a crude speaker
    embedding, padded or truncated to a fixed dimension.
    """
    def __init__(self, device="cpu"):
        self.device = device
        self.embedding_dim = 128
        self.model_loaded = True
        # Build the MFCC transform once instead of on every call
        self.mfcc_transform = torchaudio.transforms.MFCC(
            sample_rate=SAMPLE_RATE,
            n_mfcc=13,
            melkwargs={'n_mels': 40}
        )

    def load_model(self):
        """No-op kept for interface compatibility with a real model loader"""
        return True

    def embed_utterance(self, audio, sr=SAMPLE_RATE):
        """Extract simple spectral features as a speaker embedding"""
        try:
            if isinstance(audio, np.ndarray):
                waveform = torch.tensor(audio, dtype=torch.float32)
            else:
                waveform = audio

            # Ensure shape (channels, samples)
            if len(waveform.shape) == 1:
                waveform = waveform.unsqueeze(0)

            if sr != SAMPLE_RATE:
                waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=SAMPLE_RATE)

            mfcc = self.mfcc_transform(waveform)

            # Average over time, then flatten channel x coefficient values
            embedding = mfcc.mean(dim=2).flatten()

            # Pad with zeros or truncate to the fixed embedding dimension
            if len(embedding) > self.embedding_dim:
                embedding = embedding[:self.embedding_dim]
            elif len(embedding) < self.embedding_dim:
                padding = torch.zeros(self.embedding_dim - len(embedding))
                embedding = torch.cat([embedding, padding])

            return embedding.numpy()

        except Exception as e:
            print(f"Error extracting embedding: {e}")
            # Fallback: random embedding so downstream code keeps working
            return np.random.randn(self.embedding_dim)
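
# Example (sketch, assuming a 1 s mono clip at 16 kHz):
#   enc = SpeechBrainEncoder()
#   emb = enc.embed_utterance(np.random.randn(16000).astype(np.float32))
#   emb.shape  # -> (128,)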
|
|
|
|
|
class SpeakerChangeDetector:
    """Online speaker change detector for real-time diarization"""
    def __init__(self, embedding_dim=128, change_threshold=DEFAULT_CHANGE_THRESHOLD, max_speakers=DEFAULT_MAX_SPEAKERS):
        self.embedding_dim = embedding_dim
        self.change_threshold = change_threshold
        self.max_speakers = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)
        self.current_speaker = 0
        self.previous_embeddings = []
        self.last_change_time = time.time()
        self.mean_embeddings = [None] * self.max_speakers  # running mean per speaker
        self.speaker_embeddings = [[] for _ in range(self.max_speakers)]
        self.last_similarity = 0.0
        self.active_speakers = set([0])
|
|
|
    def set_max_speakers(self, max_speakers):
        """Update the maximum number of speakers"""
        new_max = min(max_speakers, ABSOLUTE_MAX_SPEAKERS)

        if new_max < self.max_speakers:
            # Retire any speaker slots that no longer exist
            for speaker_id in list(self.active_speakers):
                if speaker_id >= new_max:
                    self.active_speakers.discard(speaker_id)
            if self.current_speaker >= new_max:
                self.current_speaker = 0

        if new_max > self.max_speakers:
            self.mean_embeddings.extend([None] * (new_max - self.max_speakers))
            self.speaker_embeddings.extend([[] for _ in range(new_max - self.max_speakers)])
        else:
            self.mean_embeddings = self.mean_embeddings[:new_max]
            self.speaker_embeddings = self.speaker_embeddings[:new_max]

        self.max_speakers = new_max
|
|
|
    def set_change_threshold(self, threshold):
        """Update the speaker-change threshold (clamped to [0.1, 0.99])"""
        self.change_threshold = max(0.1, min(threshold, 0.99))
|
|
|
    def add_embedding(self, embedding, timestamp=None):
        """Add a new embedding and check if there's a speaker change"""
        current_time = timestamp or time.time()

        # First embedding: bootstrap speaker 0
        if not self.previous_embeddings:
            self.previous_embeddings.append(embedding)
            self.speaker_embeddings[self.current_speaker].append(embedding)
            if self.mean_embeddings[self.current_speaker] is None:
                self.mean_embeddings[self.current_speaker] = embedding.copy()
            return self.current_speaker, 1.0

        # Similarity to the current speaker's running mean (or last embedding)
        current_mean = self.mean_embeddings[self.current_speaker]
        if current_mean is not None:
            similarity = 1.0 - cosine(embedding, current_mean)
        else:
            similarity = 1.0 - cosine(embedding, self.previous_embeddings[-1])

        self.last_similarity = similarity

        time_since_last_change = current_time - self.last_change_time
        is_speaker_change = False

        # Only consider a change after the minimum segment duration has elapsed
        if time_since_last_change >= MIN_SEGMENT_DURATION:
            if similarity < self.change_threshold:
                # Look for a better-matching known speaker
                best_speaker = self.current_speaker
                best_similarity = similarity

                for speaker_id in range(self.max_speakers):
                    if speaker_id == self.current_speaker:
                        continue
                    speaker_mean = self.mean_embeddings[speaker_id]
                    if speaker_mean is not None:
                        speaker_similarity = 1.0 - cosine(embedding, speaker_mean)
                        if speaker_similarity > best_similarity:
                            best_similarity = speaker_similarity
                            best_speaker = speaker_id

                if best_speaker != self.current_speaker:
                    is_speaker_change = True
                    self.current_speaker = best_speaker
                elif len(self.active_speakers) < self.max_speakers:
                    # No good match: open the next free speaker slot
                    for new_id in range(self.max_speakers):
                        if new_id not in self.active_speakers:
                            is_speaker_change = True
                            self.current_speaker = new_id
                            self.active_speakers.add(new_id)
                            break

        if is_speaker_change:
            self.last_change_time = current_time

        # Keep a bounded history of recent embeddings
        self.previous_embeddings.append(embedding)
        if len(self.previous_embeddings) > EMBEDDING_HISTORY_SIZE:
            self.previous_embeddings.pop(0)

        self.speaker_embeddings[self.current_speaker].append(embedding)
        self.active_speakers.add(self.current_speaker)

        # Cap per-speaker history and refresh the running mean
        if len(self.speaker_embeddings[self.current_speaker]) > 30:
            self.speaker_embeddings[self.current_speaker] = self.speaker_embeddings[self.current_speaker][-30:]

        if self.speaker_embeddings[self.current_speaker]:
            self.mean_embeddings[self.current_speaker] = np.mean(
                self.speaker_embeddings[self.current_speaker], axis=0
            )

        return self.current_speaker, similarity
|
|
|
    def get_color_for_speaker(self, speaker_id):
        """Return color for speaker ID"""
        if 0 <= speaker_id < len(SPEAKER_COLORS):
            return SPEAKER_COLORS[speaker_id]
        return "#FFFFFF"
|
|
|
|
|
class RealTimeASRDiarization:
    """Main class for real-time ASR with speaker diarization"""
    def __init__(self):
        self.encoder = SpeechBrainEncoder()
        self.encoder.load_model()
        self.speaker_detector = SpeakerChangeDetector()
        self.transcription_queue = queue.Queue()
        self.conversation_history = []
        self.is_processing = False

        # openai-whisper is optional; fall back to mock transcription without it
        # (broad except: load_model can also fail on download, not just import)
        try:
            import whisper
            self.whisper_model = whisper.load_model("base")
        except Exception as e:
            print(f"Whisper not available ({e}), using mock transcription")
            self.whisper_model = None
|
|
|
    def transcribe_audio(self, audio_data, sr=SAMPLE_RATE):
        """Transcribe audio using Whisper"""
        try:
            if self.whisper_model is None:
                return "Mock transcription: Hello, this is a test."

            # Gradio may hand us a (sample_rate, samples) tuple
            if isinstance(audio_data, tuple):
                sr, audio_data = audio_data

            # Downmix to mono
            if len(audio_data.shape) > 1:
                audio_data = audio_data.mean(axis=1)

            # Normalize to [-1, 1] float32, as Whisper expects
            audio_data = audio_data.astype(np.float32)
            if np.abs(audio_data).max() > 1.0:
                audio_data = audio_data / np.abs(audio_data).max()

            if sr != SAMPLE_RATE:
                audio_data = librosa.resample(audio_data, orig_sr=sr, target_sr=SAMPLE_RATE)

            result = self.whisper_model.transcribe(audio_data, language=TRANSCRIPTION_LANGUAGE)
            return result["text"].strip()

        except Exception as e:
            print(f"Transcription error: {e}")
            return ""
|
|
|
    def extract_speaker_embedding(self, audio_data, sr=SAMPLE_RATE):
        """Extract speaker embedding from audio"""
        return self.encoder.embed_utterance(audio_data, sr)
|
|
|
    def process_audio_segment(self, audio_data, sr=SAMPLE_RATE):
        """Transcribe a segment and identify its speaker.

        Returns (transcription, speaker_id, similarity), or (None, None, None)
        for segments shorter than 0.5 s or with an empty transcription.
        """
        # Skip segments too short to transcribe reliably
        if len(audio_data) < sr * 0.5:
            return None, None, None

        transcription = self.transcribe_audio(audio_data, sr)
        if not transcription:
            return None, None, None

        embedding = self.extract_speaker_embedding(audio_data, sr)
        speaker_id, similarity = self.speaker_detector.add_embedding(embedding)

        return transcription, speaker_id, similarity
|
|
|
    def update_conversation(self, transcription, speaker_id):
        """Update conversation history with new transcription"""
        speaker_name = f"Speaker {speaker_id + 1}"
        color = self.speaker_detector.get_color_for_speaker(speaker_id)

        entry = {
            "speaker": speaker_name,
            "text": transcription,
            "color": color,
            "timestamp": time.time()
        }

        self.conversation_history.append(entry)
        return entry
|
|
|
    def format_conversation_html(self):
        """Format conversation history as HTML"""
        if not self.conversation_history:
            return "<p><i>No conversation yet. Start speaking to see real-time transcription with speaker diarization.</i></p>"

        html_parts = []
        for entry in self.conversation_history:
            html_parts.append(
                f'<p><span style="color: {entry["color"]}; font-weight: bold;">'
                f'{entry["speaker"]}:</span> {entry["text"]}</p>'
            )

        return "".join(html_parts)
|
|
|
    def get_status_info(self):
        """Get current status information"""
        return {
            "active_speakers": len(self.speaker_detector.active_speakers),
            "max_speakers": self.speaker_detector.max_speakers,
            "current_speaker": self.speaker_detector.current_speaker + 1,
            "total_segments": len(self.conversation_history),
            "threshold": self.speaker_detector.change_threshold
        }
|
|
|
    def clear_conversation(self):
        """Clear conversation history and reset speaker detector"""
        self.conversation_history = []
        # Rebuild the detector but keep the user-facing parameters
        self.speaker_detector = SpeakerChangeDetector(
            change_threshold=self.speaker_detector.change_threshold,
            max_speakers=self.speaker_detector.max_speakers
        )
|
|
|
    def set_parameters(self, threshold, max_speakers):
        """Update parameters"""
        self.speaker_detector.set_change_threshold(threshold)
        self.speaker_detector.set_max_speakers(max_speakers)
|
|
|
|
|
|
|
# Single global pipeline instance shared by all Gradio callbacks
asr_system = RealTimeASRDiarization()
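
# Example (sketch): the same pipeline also works offline on a clip:
#   sr, clip = 16000, np.random.randn(16000).astype(np.float32)
#   text, spk, sim = asr_system.process_audio_segment(clip, sr)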
|
|
|
|
|
def process_audio_realtime(audio_data, threshold, max_speakers):
    """Process audio in real-time"""
    global asr_system

    if audio_data is None:
        return asr_system.format_conversation_html(), get_status_display()

    asr_system.set_parameters(threshold, max_speakers)

    try:
        sr, audio_array = audio_data

        # Convert integer PCM to float32 in [-1, 1]. The dtype must be checked
        # *before* casting, otherwise the scaling branches never run.
        if audio_array.dtype == np.int16:
            audio_array = audio_array.astype(np.float32) / 32768.0
        elif audio_array.dtype == np.int32:
            audio_array = audio_array.astype(np.float32) / 2147483648.0
        elif audio_array.dtype != np.float32:
            audio_array = audio_array.astype(np.float32)

        transcription, speaker_id, similarity = asr_system.process_audio_segment(audio_array, sr)

        if transcription and speaker_id is not None:
            asr_system.update_conversation(transcription, speaker_id)

    except Exception as e:
        print(f"Error processing audio: {e}")

    return asr_system.format_conversation_html(), get_status_display()
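
# Note: each streamed chunk is transcribed independently. A production version
# would buffer chunks (e.g., in a collections.deque) and cut segments at
# pauses before handing them to Whisper.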
|
|
|
|
|
def get_status_display():
    """Get formatted status display"""
    status = asr_system.get_status_info()

    return f"""
    <div style="font-family: monospace; font-size: 12px;">
    <strong>Status:</strong><br>
    Current Speaker: {status['current_speaker']}<br>
    Active Speakers: {status['active_speakers']} / {status['max_speakers']}<br>
    Total Segments: {status['total_segments']}<br>
    Threshold: {status['threshold']:.2f}<br>
    </div>
    """
|
|
|
|
|
def clear_conversation():
    """Clear the conversation"""
    global asr_system
    asr_system.clear_conversation()
    return asr_system.format_conversation_html(), get_status_display()
|
|
|
|
|
def create_interface():
    """Create Gradio interface"""
    with gr.Blocks(
        title="Real-time ASR with Speaker Diarization",
        theme=gr.themes.Soft(),
        css="""
        .conversation-box {
            height: 400px;
            overflow-y: auto;
            border: 1px solid #ddd;
            padding: 10px;
            background-color: #f9f9f9;
        }
        .status-box {
            border: 1px solid #ccc;
            padding: 10px;
            background-color: #f0f0f0;
        }
        """
    ) as demo:

        gr.Markdown(
            """
            # 🎤 Real-time ASR with Live Speaker Diarization

            This application provides real-time speech recognition with speaker diarization.
            It distinguishes between speakers and displays each speaker's words in a different color.

            **Instructions:**
            1. Adjust the speaker change threshold and the maximum number of speakers
            2. Click the microphone button to start recording
            3. Speak naturally - the system will detect speaker changes and transcribe speech
            4. Each speaker is assigned a different color
            """
        )
|
|
|
        with gr.Row():
            with gr.Column(scale=3):
                conversation_display = gr.HTML(
                    value="<p><i>Click the microphone to start recording...</i></p>",
                    elem_classes=["conversation-box"]
                )

                # Gradio 3.x signature; Gradio 4+ renamed `source` to `sources=["microphone"]`
                audio_input = gr.Audio(
                    source="microphone",
                    type="numpy",
                    streaming=True,
                    label="🎤 Microphone Input"
                )
|
|
|
            with gr.Column(scale=1):
                gr.Markdown("### Controls")

                threshold_slider = gr.Slider(
                    minimum=0.1,
                    maximum=0.9,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    step=0.05,
                    label="Speaker Change Threshold",
                    info="Higher values make speaker changes more likely (a change is considered when similarity falls below this value)"
                )

                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    value=DEFAULT_MAX_SPEAKERS,
                    step=1,
                    label="Maximum Speakers",
                    info="Maximum number of different speakers to detect"
                )

                clear_btn = gr.Button("🗑️ Clear Conversation", variant="secondary")

                gr.Markdown("### Status")
                status_display = gr.HTML(
                    value=get_status_display(),
                    elem_classes=["status-box"]
                )
|
|
|
|
|
gr.Markdown("### Speaker Colors") |
|
legend_html = "" |
|
for i in range(ABSOLUTE_MAX_SPEAKERS): |
|
color = SPEAKER_COLORS[i] |
|
name = SPEAKER_COLOR_NAMES[i] |
|
legend_html += f'<p><span style="color: {color}; font-weight: bold;">● Speaker {i+1} ({name})</span></p>' |
|
|
|
gr.HTML(legend_html) |
|
|
|
|
|
        # Re-run on every streamed chunk (on newer Gradio versions the
        # dedicated `audio_input.stream(...)` event may be more appropriate)
        audio_input.change(
            fn=process_audio_realtime,
            inputs=[audio_input, threshold_slider, max_speakers_slider],
            outputs=[conversation_display, status_display],
            show_progress=False
        )

        clear_btn.click(
            fn=clear_conversation,
            outputs=[conversation_display, status_display]
        )

        # Poll every 2 seconds so the display stays fresh between audio events
        demo.load(
            fn=lambda: (asr_system.format_conversation_html(), get_status_display()),
            outputs=[conversation_display, status_display],
            every=2
        )

    return demo
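
# Note: the interface uses the Gradio 3.x API (`source=`, `streaming=True`,
# `every=` polling). On Gradio 4+ these were renamed or replaced
# (`sources=["microphone"]`, `gr.Timer`); check your installed version.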
|
|
|
|
|
if __name__ == "__main__":
    demo = create_interface()
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=True
    )
|
|