# NOTE(review): the lines "Spaces:" / "Sleeping" below were Hugging Face
# page-status chrome captured by the scrape, not part of the source file.
import gradio as gr
from fastapi import FastAPI
from shared import DEFAULT_CHANGE_THRESHOLD, DEFAULT_MAX_SPEAKERS, ABSOLUTE_MAX_SPEAKERS

# Log the installed Gradio version at import time (deployment debugging aid).
print(gr.__version__)

# Connection configuration (separate signaling server from model server)
# These will be replaced at deployment time with the correct URLs
RENDER_SIGNALING_URL = "wss://render-signal-audio.onrender.com/stream"
HF_SPACE_URL = "https://androidguy-speaker-diarization.hf.space"
def build_ui():
    """Build the Gradio Blocks UI for real-time speaker diarization.

    The Python side only lays out components. Audio capture, WebRTC
    signaling and live conversation updates are driven entirely by
    JavaScript embedded in gr.HTML components, which talks to the Render
    signaling server (RENDER_SIGNALING_URL) and the HF Space (HF_SPACE_URL).

    Returns:
        gr.Blocks: the assembled, unlaunched Gradio app.
    """
    # NOTE(review): several emoji in the UI strings below are mojibake from a
    # lossy re-encoding of the original file; they are preserved as-is here.
    with gr.Blocks(title="Real-time Speaker Diarization", theme=gr.themes.Soft()) as demo:
        # Expose the deployment URLs to the embedded JavaScript as globals.
        # NOTE(review): <script> tags inside gr.HTML are not executed by every
        # Gradio version -- confirm against the deployed Gradio release.
        gr.HTML(
            f"""
            <!-- Configuration parameters -->
            <script>
                window.RENDER_SIGNALING_URL = "{RENDER_SIGNALING_URL}";
                window.HF_SPACE_URL = "{HF_SPACE_URL}";
            </script>
            """
        )

        # Header and description
        gr.Markdown("# π€ Live Speaker Diarization")
        gr.Markdown("Real-time speech recognition with automatic speaker identification")

        # Status indicator; the #status-text / #status-icon elements are
        # updated by updateStatus() in the JavaScript below.
        connection_status = gr.HTML(
            """<div class="status-indicator">
                <span id="status-text" style="color:#888;">Waiting to connect...</span>
                <span id="status-icon" style="width:10px; height:10px; display:inline-block;
                    background-color:#888; border-radius:50%; margin-left:5px;"></span>
            </div>"""
        )

        with gr.Row():
            with gr.Column(scale=2):
                # Conversation display with embedded JavaScript for WebRTC and audio handling
                conversation_display = gr.HTML(
                    """
                    <div class='output' id="conversation" style='padding:20px; background:#111; border-radius:10px;
                        min-height:400px; font-family:Arial; font-size:16px; line-height:1.5; overflow-y:auto;'>
                        <i>Click 'Start Listening' to begin...</i>
                    </div>

                    <script>
                    // Global variables
                    let rtcConnection;
                    let mediaStream;
                    let wsConnection;
                    let statusUpdateInterval;

                    // Check connection to HF space
                    async function checkHfConnection() {
                        try {
                            let response = await fetch(`${window.HF_SPACE_URL}/health`);
                            return response.ok;
                        } catch (err) {
                            return false;
                        }
                    }

                    // Start the connection and audio streaming
                    async function startStreaming() {
                        try {
                            // Update status
                            updateStatus('connecting');

                            // Request microphone access
                            mediaStream = await navigator.mediaDevices.getUserMedia({audio: {
                                echoCancellation: true,
                                noiseSuppression: true,
                                autoGainControl: true
                            }});

                            // Set up WebRTC connection to Render signaling server
                            await setupWebRTC();

                            // Also connect WebSocket directly to HF Space for conversation updates
                            setupWebSocket();

                            // Start status update interval
                            statusUpdateInterval = setInterval(updateConnectionInfo, 5000);

                            // Update status
                            updateStatus('connected');
                            document.getElementById("conversation").innerHTML = "<i>Connected! Start speaking...</i>";
                        } catch (err) {
                            console.error('Error starting stream:', err);
                            updateStatus('error', err.message);
                        }
                    }

                    // Set up WebRTC connection to Render signaling server
                    async function setupWebRTC() {
                        try {
                            if (rtcConnection) {
                                rtcConnection.close();
                            }

                            // Use FastRTC's connection approach
                            const pc = new RTCPeerConnection({
                                iceServers: [{ urls: 'stun:stun.l.google.com:19302' }]
                            });

                            // Add audio track
                            mediaStream.getAudioTracks().forEach(track => {
                                pc.addTrack(track, mediaStream);
                            });

                            // Connect to FastRTC signaling via WebSocket.
                            // (A previous no-op replace('wss://', 'wss://') was removed.)
                            const signalWs = new WebSocket(window.RENDER_SIGNALING_URL);

                            // Handle signaling messages
                            signalWs.onmessage = async (event) => {
                                const message = JSON.parse(event.data);

                                if (message.type === 'offer') {
                                    await pc.setRemoteDescription(new RTCSessionDescription(message));
                                    const answer = await pc.createAnswer();
                                    await pc.setLocalDescription(answer);
                                    signalWs.send(JSON.stringify(pc.localDescription));
                                } else if (message.type === 'candidate') {
                                    if (message.candidate) {
                                        // The sender wraps the ICE candidate under .candidate
                                        // (see onicecandidate below), so reconstruct from that
                                        // field rather than the whole signaling envelope.
                                        await pc.addIceCandidate(new RTCIceCandidate(message.candidate));
                                    }
                                }
                            };

                            // Send ICE candidates
                            pc.onicecandidate = (event) => {
                                if (event.candidate) {
                                    signalWs.send(JSON.stringify({
                                        type: 'candidate',
                                        candidate: event.candidate
                                    }));
                                }
                            };

                            // Keep connection reference
                            rtcConnection = pc;

                            // Wait for connection to be established
                            await new Promise((resolve, reject) => {
                                const timeout = setTimeout(() => reject(new Error("WebRTC connection timeout")), 10000);
                                pc.onconnectionstatechange = () => {
                                    if (pc.connectionState === 'connected') {
                                        clearTimeout(timeout);
                                        resolve();
                                    } else if (pc.connectionState === 'failed' || pc.connectionState === 'disconnected') {
                                        clearTimeout(timeout);
                                        reject(new Error("WebRTC connection failed"));
                                    }
                                };
                            });

                            updateStatus('connected');
                        } catch (err) {
                            console.error('WebRTC setup error:', err);
                            updateStatus('error', 'WebRTC setup failed: ' + err.message);
                        }
                    }

                    // Set up WebSocket connection to HF Space for conversation updates
                    function setupWebSocket() {
                        const wsUrl = window.RENDER_SIGNALING_URL.replace('stream', 'ws_relay');
                        wsConnection = new WebSocket(wsUrl);

                        wsConnection.onopen = () => {
                            console.log('WebSocket connection established');
                        };

                        wsConnection.onmessage = (event) => {
                            try {
                                // Parse the JSON message
                                const message = JSON.parse(event.data);

                                // Process different message types
                                switch(message.type) {
                                    case 'transcription':
                                        // Handle transcription data
                                        if (message.data && typeof message.data === 'object') {
                                            document.getElementById("conversation").innerHTML = message.data.conversation_html ||
                                                JSON.stringify(message.data);
                                        }
                                        break;

                                    case 'processing_result':
                                        // Handle individual audio chunk processing result
                                        console.log('Processing result:', message.data);

                                        // Update status info if needed
                                        if (message.data && message.data.status === "processed") {
                                            const statusElem = document.getElementById('status-text');
                                            if (statusElem) {
                                                const speakerId = message.data.speaker_id !== undefined ?
                                                    `Speaker ${message.data.speaker_id + 1}` : '';
                                                if (speakerId) {
                                                    statusElem.textContent = `Connected - ${speakerId} active`;
                                                }
                                            }
                                        } else if (message.data && message.data.status === "error") {
                                            updateStatus('error', message.data.message || 'Processing error');
                                        }
                                        break;

                                    case 'connection':
                                        console.log('Connection status:', message.status);
                                        updateStatus(message.status === 'connected' ? 'connected' : 'warning');
                                        break;

                                    case 'connection_established':
                                        console.log('Connection established:', message);
                                        updateStatus('connected');

                                        // If initial conversation is provided, display it
                                        if (message.conversation) {
                                            document.getElementById("conversation").innerHTML = message.conversation;
                                        }
                                        break;

                                    case 'conversation_update':
                                        if (message.conversation_html) {
                                            document.getElementById("conversation").innerHTML = message.conversation_html;
                                        }
                                        break;

                                    case 'conversation_cleared':
                                        document.getElementById("conversation").innerHTML =
                                            "<i>Conversation cleared. Start speaking again...</i>";
                                        break;

                                    case 'error':
                                        console.error('Error message from server:', message.message);
                                        updateStatus('warning', message.message);
                                        break;

                                    default:
                                        // If it's just HTML content without proper JSON structure (legacy format)
                                        document.getElementById("conversation").innerHTML = event.data;
                                }

                                // Auto-scroll to bottom
                                const container = document.getElementById("conversation");
                                container.scrollTop = container.scrollHeight;
                            } catch (e) {
                                // Fallback for non-JSON messages (legacy format)
                                document.getElementById("conversation").innerHTML = event.data;

                                // Auto-scroll to bottom
                                const container = document.getElementById("conversation");
                                container.scrollTop = container.scrollHeight;
                            }
                        };

                        wsConnection.onerror = (error) => {
                            console.error('WebSocket error:', error);
                            updateStatus('warning', 'WebSocket error');
                        };

                        wsConnection.onclose = () => {
                            console.log('WebSocket connection closed');
                            // Try to reconnect after a delay
                            setTimeout(setupWebSocket, 3000);
                        };
                    }

                    // Update connection info in the UI
                    async function updateConnectionInfo() {
                        try {
                            const hfConnected = await checkHfConnection();
                            if (!hfConnected) {
                                updateStatus('warning', 'HF Space connection issue');
                            } else if (rtcConnection?.connectionState === 'connected' ||
                                       rtcConnection?.iceConnectionState === 'connected') {
                                updateStatus('connected');
                            } else {
                                updateStatus('warning', 'Connection unstable');
                            }
                        } catch (err) {
                            console.error('Error updating connection info:', err);
                        }
                    }

                    // Update status indicator
                    function updateStatus(status, message = '') {
                        const statusText = document.getElementById('status-text');
                        const statusIcon = document.getElementById('status-icon');

                        switch(status) {
                            case 'connected':
                                statusText.textContent = 'Connected';
                                statusIcon.style.backgroundColor = '#4CAF50';
                                break;
                            case 'connecting':
                                statusText.textContent = 'Connecting...';
                                statusIcon.style.backgroundColor = '#FFC107';
                                break;
                            case 'disconnected':
                                statusText.textContent = 'Disconnected';
                                statusIcon.style.backgroundColor = '#9E9E9E';
                                break;
                            case 'error':
                                statusText.textContent = 'Error: ' + message;
                                statusIcon.style.backgroundColor = '#F44336';
                                break;
                            case 'warning':
                                statusText.textContent = 'Warning: ' + message;
                                statusIcon.style.backgroundColor = '#FF9800';
                                break;
                            default:
                                statusText.textContent = 'Unknown';
                                statusIcon.style.backgroundColor = '#9E9E9E';
                        }
                    }

                    // Stop streaming and clean up
                    function stopStreaming() {
                        // Close WebRTC connection
                        if (rtcConnection) {
                            rtcConnection.close();
                            rtcConnection = null;
                        }

                        // Close WebSocket
                        if (wsConnection) {
                            wsConnection.close();
                            wsConnection = null;
                        }

                        // Stop all tracks in media stream
                        if (mediaStream) {
                            mediaStream.getTracks().forEach(track => track.stop());
                            mediaStream = null;
                        }

                        // Clear interval
                        if (statusUpdateInterval) {
                            clearInterval(statusUpdateInterval);
                            statusUpdateInterval = null;
                        }

                        // Update status
                        updateStatus('disconnected');
                    }

                    // Set up event listeners when the DOM is loaded
                    document.addEventListener('DOMContentLoaded', () => {
                        updateStatus('disconnected');
                    });
                    </script>
                    """,
                    label="Live Conversation"
                )

                # Control buttons
                with gr.Row():
                    start_btn = gr.Button("βΆοΈ Start Listening", variant="primary", size="lg")
                    stop_btn = gr.Button("βΉοΈ Stop", variant="stop", size="lg")
                    clear_btn = gr.Button("ποΈ Clear", variant="secondary", size="lg")

                # Status display, refreshed by the gr.Timer below.
                status_output = gr.Markdown(
                    """
                    ## System Status
                    Waiting to connect...

                    *Click Start Listening to begin*
                    """,
                    label="Status Information"
                )

            with gr.Column(scale=1):
                # Settings
                gr.Markdown("## βοΈ Settings")

                threshold_slider = gr.Slider(
                    minimum=0.3,
                    maximum=0.9,
                    step=0.05,
                    value=DEFAULT_CHANGE_THRESHOLD,
                    label="Speaker Change Sensitivity",
                    info="Lower = more sensitive (more speaker changes)"
                )

                max_speakers_slider = gr.Slider(
                    minimum=2,
                    maximum=ABSOLUTE_MAX_SPEAKERS,
                    step=1,
                    value=DEFAULT_MAX_SPEAKERS,
                    label="Maximum Speakers"
                )

                update_btn = gr.Button("Update Settings", variant="secondary")

                # Instructions
                gr.Markdown("""
                ## π Instructions
                1. **Start Listening** - allows browser to access microphone
                2. **Speak** - system will transcribe and identify speakers
                3. **Stop** when finished
                4. **Clear** to reset conversation

                ## π¨ Speaker Colors
                - π΄ Speaker 1 (Red)
                - π’ Speaker 2 (Teal)
                - π΅ Speaker 3 (Blue)
                - π‘ Speaker 4 (Green)
                - β Speaker 5 (Yellow)
                - π£ Speaker 6 (Plum)
                - π€ Speaker 7 (Mint)
                - π Speaker 8 (Gold)
                """)

        # JavaScript to connect buttons to the script functions.
        # NOTE(review): the aria-label selectors below assume Gradio renders
        # button labels as aria-label attributes -- verify against the
        # deployed Gradio version; matching by button text may be needed.
        gr.HTML("""
        <script>
        // Wait for Gradio to fully load
        document.addEventListener('DOMContentLoaded', () => {
            // Wait a bit for Gradio buttons to be created
            setTimeout(() => {
                // Get the buttons
                const startBtn = document.querySelector('button[aria-label="Start Listening"]');
                const stopBtn = document.querySelector('button[aria-label="Stop"]');
                const clearBtn = document.querySelector('button[aria-label="Clear"]');

                if (startBtn) startBtn.onclick = () => startStreaming();
                if (stopBtn) stopBtn.onclick = () => stopStreaming();
                if (clearBtn) clearBtn.onclick = () => {
                    // Make API call to clear conversation
                    fetch(`${window.HF_SPACE_URL}/clear`, {
                        method: 'POST'
                    }).then(resp => resp.json())
                      .then(data => {
                        document.getElementById("conversation").innerHTML =
                            "<i>Conversation cleared. Start speaking again...</i>";
                    });
                }

                // Set up settings update
                const updateBtn = document.querySelector('button[aria-label="Update Settings"]');
                if (updateBtn) updateBtn.onclick = () => {
                    const threshold = document.querySelector('input[aria-label="Speaker Change Sensitivity"]').value;
                    const maxSpeakers = document.querySelector('input[aria-label="Maximum Speakers"]').value;

                    fetch(`${window.HF_SPACE_URL}/settings?threshold=${threshold}&max_speakers=${maxSpeakers}`, {
                        method: 'POST'
                    }).then(resp => resp.json())
                      .then(data => {
                        const statusOutput = document.querySelector('.prose');
                        if (statusOutput) {
                            statusOutput.innerHTML = `
                                <h2>System Status</h2>
                                <p>Settings updated:</p>
                                <ul>
                                    <li>Threshold: ${threshold}</li>
                                    <li>Max Speakers: ${maxSpeakers}</li>
                                </ul>
                            `;
                        }
                    });
                }
            }, 1000);
        });
        </script>
        """)

        # Set up periodic status updates
        def get_status():
            """Poll the HF Space /status endpoint; returns a status string."""
            import requests
            try:
                # Timeout added: without it a stalled HF Space would block the
                # 5-second timer callback indefinitely.
                resp = requests.get(f"{HF_SPACE_URL}/status", timeout=5)
                if resp.status_code == 200:
                    return resp.json().get('status', 'No status information')
                return "Error getting status"
            except Exception as e:
                return f"Connection error: {str(e)}"

        status_timer = gr.Timer(5)
        status_timer.tick(fn=get_status, outputs=status_output)

    return demo
# Create Gradio interface
# Module-level singleton UI instance, shared by mount_ui() and the
# standalone __main__ launcher below.
demo = build_ui()
def mount_ui(app: FastAPI) -> None:
    """Mount the Gradio UI onto a FastAPI application at /ui.

    Uses gr.mount_gradio_app, the supported way to embed a Blocks app in an
    existing FastAPI app. The previous `app.mount("/ui", demo.app)` was
    broken: `Blocks.app` is only populated after `demo.launch()`, so at
    import time it is not a mountable ASGI application.

    Args:
        app: The FastAPI application to attach the UI to.
    """
    gr.mount_gradio_app(app, demo, path="/ui")
# For standalone testing
# Launches the UI directly (without the FastAPI wrapper) when the module
# is run as a script.
if __name__ == "__main__":
    demo.launch()