# v2v-omni-21/server.py: FastAPI wrapper around the OmniInference speech-to-speech model.
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import traceback
import numpy as np
import torch
import base64
import io
import os
import logging
import whisper  # used only for whisper.pad_or_trim below
import soundfile as sf
from inference import OmniInference
import tempfile
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI()
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
class AudioRequest(BaseModel):
    """Request payload: base64-encoded bytes of an .npy-serialized audio array."""
    audio_data: str
    sample_rate: int

class AudioResponse(BaseModel):
    """Response payload: base64-encoded .npy audio plus the accumulated text output."""
    audio_data: str
    text: str = ""
# Model initialization status
INITIALIZATION_STATUS = {
"model_loaded": False,
"error": None
}
# Global model instance
model = None
def initialize_model():
"""Initialize the OmniInference model"""
global model, INITIALIZATION_STATUS
try:
device = "cuda" if torch.cuda.is_available() else "cpu"
logger.info(f"Initializing OmniInference model on device: {device}")
ckpt_path = os.path.abspath('models')
logger.info(f"Loading models from: {ckpt_path}")
if not os.path.exists(ckpt_path):
raise RuntimeError(f"Checkpoint path {ckpt_path} does not exist")
model = OmniInference(ckpt_path, device=device)
model.warm_up()
INITIALIZATION_STATUS["model_loaded"] = True
logger.info("OmniInference model initialized successfully")
return True
except Exception as e:
INITIALIZATION_STATUS["error"] = str(e)
logger.error(f"Failed to initialize model: {e}\n{traceback.format_exc()}")
return False
@app.on_event("startup")  # deprecated in newer FastAPI in favor of lifespan handlers, but still functional
async def startup_event():
"""Initialize model on startup"""
initialize_model()
@app.get("/api/v1/health")
def health_check():
"""Health check endpoint"""
status = {
"status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
"initialization_status": INITIALIZATION_STATUS
}
if model is not None:
status.update({
"device": str(model.device),
"model_loaded": True,
"warm_up_complete": True
})
return status
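# Example health probe (a usage sketch, not part of the original file; the host and
# port assumptions match the __main__ block at the bottom):
#   curl http://localhost:8000/api/v1/health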
@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
"""Run inference with OmniInference model"""
if not INITIALIZATION_STATUS["model_loaded"]:
raise HTTPException(
status_code=503,
detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
)
try:
logger.info(f"Received inference request with sample rate: {request.sample_rate}")
# Decode audio data from base64 to numpy array
audio_bytes = base64.b64decode(request.audio_data)
audio_array = np.load(io.BytesIO(audio_bytes))
# Save numpy array as temporary WAV file for OmniInference
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
# Convert sample rate if needed (OmniInference expects 16kHz)
            if request.sample_rate != 16000:
                # Resampling is not implemented; see the commented sketch below.
                logger.warning("Sample rate conversion not implemented. Assuming 16kHz.")
            # Pad or trim to whisper's fixed 30-second window (mono, 1-D), then write
            # a 16 kHz WAV with soundfile for OmniInference to consume.
            audio_data = whisper.pad_or_trim(audio_array.flatten())
            sf.write(temp_wav.name, audio_data, 16000)
# Run inference with streaming
final_text = ""
with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav_out:
                # Stream (audio, text) chunks from the generator; audio is written
                # incrementally to save_path while the text is accumulated here.
                for audio_stream, text_stream in model.run_AT_batch_stream(
                    temp_wav.name,
                    stream_stride=4,
                    max_returned_tokens=2048,
                    save_path=temp_wav_out.name,
                    sample_rate=request.sample_rate
                ):
                    if text_stream:  # accumulate non-empty text chunks
                        final_text += text_stream
                # Read back the complete waveform; the model is expected to have
                # written it at the requested rate (passed via sample_rate above).
                final_audio, sample_rate = sf.read(temp_wav_out.name)
                assert sample_rate == request.sample_rate
# Encode output array to base64
buffer = io.BytesIO()
np.save(buffer, final_audio)
audio_b64 = base64.b64encode(buffer.getvalue()).decode()
return AudioResponse(
audio_data=audio_b64,
text=final_text.strip()
)
except Exception as e:
logger.error(f"Inference failed: {str(e)}", exc_info=True)
raise HTTPException(
status_code=500,
detail=str(e)
)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
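# Example client call (a usage sketch, not part of the original file; it assumes the
# server is reachable at localhost:8000 and that `requests` is installed):
#
#   import base64, io
#   import numpy as np
#   import requests
#   import soundfile as sf
#
#   audio, sr = sf.read("input.wav", dtype="float32")
#   buf = io.BytesIO()
#   np.save(buf, audio)  # the server expects a base64-encoded .npy payload
#   resp = requests.post(
#       "http://localhost:8000/api/v1/inference",
#       json={"audio_data": base64.b64encode(buf.getvalue()).decode(),
#             "sample_rate": int(sr)},
#   ).json()
#   out_audio = np.load(io.BytesIO(base64.b64decode(resp["audio_data"])))
#   print(resp["text"])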