File size: 2,420 Bytes
97a4ae5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
import logging

import uvicorn
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from fastapi.middleware.cors import CORSMiddleware
from shared import RealtimeSpeakerDiarization

# Module logger, plus root logging configured to emit INFO and above.
logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# The ASGI application that every route in this module registers on.
app = FastAPI()

# Permissive CORS so browser clients served from any origin can call the API.
# NOTE(review): a "*" origin list combined with allow_credentials=True is a
# known CORS anti-pattern (see Starlette's CORSMiddleware docs). None of the
# routes in this module read cookies or auth headers, so consider
# allow_credentials=False or an explicit origin list; confirm no client needs
# credentialed cross-origin requests first.
app.add_middleware(
    CORSMiddleware,
    allow_methods=["*"],
    allow_headers=["*"],
    allow_origins=["*"],
    allow_credentials=True,
)

# Build the diarization engine once at import time; every endpoint in this
# module shares this single instance.
diart = RealtimeSpeakerDiarization()
success = diart.initialize_models()
if success:
    logger.info("Models initialized: %s", success)
else:
    # A falsy return presumably means the model load failed; log it at ERROR
    # so it is not buried among routine INFO messages.
    logger.error("Models initialized: %s (model initialization failed)", success)
# Recording is started regardless of the initialization result.
diart.start_recording()

@app.get("/health")
async def health_check():
    """Report service health together with the diarizer's ``is_running`` flag."""
    running = diart.is_running
    return dict(status="healthy", system_running=running)

@app.websocket("/ws_inference")
async def ws_inference(ws: WebSocket):
    """WebSocket endpoint for real-time audio processing.

    Every binary frame received is passed to ``diart.process_audio_chunk``
    with ``sample_rate=16000``. The client is presumably expected to send
    16 kHz audio; confirm the sample format against ``shared``. After each
    chunk, the current ``diart.get_formatted_conversation()`` result is sent
    back as a text frame. ``iter_bytes`` ends the loop cleanly when the client
    disconnects while the server is waiting for data.
    """
    await ws.accept()
    logger.info("WebSocket connection established")

    try:
        async for chunk in ws.iter_bytes():
            # NOTE(review): this call runs synchronously on the event loop. If
            # it performs heavy inference, it stalls all other requests while
            # it runs. Consider offloading it to a thread once the thread
            # safety of the shared `diart` instance is confirmed.
            diart.process_audio_chunk(chunk, sample_rate=16000)

            # Send back conversation results
            result = diart.get_formatted_conversation()
            await ws.send_text(result)
    except WebSocketDisconnect:
        # The client went away while we were sending. That is a normal end of
        # the session, not a server error.
        logger.info("WebSocket client disconnected")
    except Exception as e:
        # logger.exception keeps the traceback so real failures can be diagnosed.
        logger.exception("WebSocket error: %s", e)
    finally:
        logger.info("WebSocket connection closed")

@app.get("/conversation")
async def get_conversation():
    """Return the diarizer's formatted conversation under the ``conversation`` key.

    The payload is whatever ``get_formatted_conversation()`` produces,
    presumably HTML markup; confirm in ``shared``.
    """
    formatted = diart.get_formatted_conversation()
    return {"conversation": formatted}

@app.get("/status")
async def get_status():
    """Expose ``diart.get_status_info()`` under the ``status`` key."""
    status_info = diart.get_status_info()
    return {"status": status_info}

@app.post("/settings")
async def update_settings(threshold: float, max_speakers: int):
    """Apply new speaker-detection settings to the shared diarizer.

    Both values arrive as query parameters, since they are plain scalar
    arguments of a FastAPI route. The diarizer's return value is echoed back
    under the ``result`` key.
    """
    return {"result": diart.update_settings(threshold, max_speakers)}

@app.post("/clear")
async def clear_conversation():
    """Reset the diarizer's conversation and echo its return value under ``result``."""
    return {"result": diart.clear_conversation()}

# Optionally mount the Gradio UI from the local `ui` module. If it cannot be
# imported, the server keeps running with only the HTTP/WebSocket API.
try:
    import ui

    ui.mount_ui(app)
    logger.info("Gradio UI mounted successfully")
except ImportError as exc:
    # ImportError is also raised when `ui` exists but one of its own imports
    # is missing (e.g. its UI framework). Log the underlying error instead of
    # assuming the module itself is absent.
    logger.warning("UI module not found, running in API-only mode (%s)", exc)

if __name__ == "__main__":
    # When executed directly, serve the app on every interface at port 7860.
    uvicorn.run(app, port=7860, host="0.0.0.0")