from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import traceback
import numpy as np
import torch
import base64
import io
import os
import logging
import whisper  # For audio loading/processing
import soundfile as sf
from inference import OmniInference
import tempfile

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AudioRequest(BaseModel):
    audio_data: str
    sample_rate: int


class AudioResponse(BaseModel):
    audio_data: str
    text: str = ""


# Model initialization status
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None
}

# Global model instance
model = None


def initialize_model():
    """Initialize the OmniInference model"""
    global model, INITIALIZATION_STATUS
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing OmniInference model on device: {device}")

        ckpt_path = os.path.abspath('models')
        logger.info(f"Loading models from: {ckpt_path}")
        if not os.path.exists(ckpt_path):
            raise RuntimeError(f"Checkpoint path {ckpt_path} does not exist")

        model = OmniInference(ckpt_path, device=device)
        model.warm_up()

        INITIALIZATION_STATUS["model_loaded"] = True
        logger.info("OmniInference model initialized successfully")
        return True
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        logger.error(f"Failed to initialize model: {e}\n{traceback.format_exc()}")
        return False


@app.on_event("startup")
async def startup_event():
    """Initialize model on startup"""
    initialize_model()


@app.get("/api/v1/health")
def health_check():
    """Health check endpoint"""
    status = {
        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
        "initialization_status": INITIALIZATION_STATUS
    }
    if model is not None:
        status.update({
            "device": str(model.device),
            "model_loaded": True,
            "warm_up_complete": True
        })
    return status


@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with OmniInference model"""
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )

    try:
        logger.info(f"Received inference request with sample rate: {request.sample_rate}")

        # Decode audio data from base64 to numpy array
        audio_bytes = base64.b64decode(request.audio_data)
        audio_array = np.load(io.BytesIO(audio_bytes))

        # Save numpy array as temporary WAV file for OmniInference
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
            # Convert sample rate if needed (OmniInference expects 16kHz)
            if request.sample_rate != 16000:
                # You might want to add resampling logic here
                logger.warning("Sample rate conversion not implemented. Assuming 16kHz.")
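                # A possible resampling sketch (an assumption, not part of the original
                # flow): torchaudio is NOT imported in this module, so add
                # `import torchaudio` at the top and uncomment to actually convert
                # instead of only warning.
                # audio_tensor = torch.from_numpy(audio_array.flatten()).float()
                # audio_array = torchaudio.functional.resample(
                #     audio_tensor, orig_freq=request.sample_rate, new_freq=16000
                # ).numpy()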
            # Write WAV file using whisper's audio utilities
            audio_data = whisper.pad_or_trim(audio_array.flatten())  # Flatten to 1D if needed
            # whisper.save_audio(audio_data, temp_wav.name, sampling_rate=16000)
            sf.write(temp_wav.name, audio_data, 16000)

            # Run inference with streaming
            final_text = ""
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav_out:
                # Get all results from generator
                for audio_stream, text_stream in model.run_AT_batch_stream(
                    temp_wav.name,
                    stream_stride=4,
                    max_returned_tokens=2048,
                    save_path=temp_wav_out.name,
                    sample_rate=request.sample_rate
                ):
                    if text_stream:  # Accumulate non-empty text
                        final_text += text_stream

                final_audio, sample_rate = sf.read(temp_wav_out.name)
                assert sample_rate == request.sample_rate

        # Encode output array to base64
        buffer = io.BytesIO()
        np.save(buffer, final_audio)
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()

        return AudioResponse(
            audio_data=audio_b64,
            text=final_text.strip()
        )

    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )


if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
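# Example client call (a minimal sketch, not exercised by this module; assumes the
# `requests` package is installed and the server is running on the host/port used by
# uvicorn above):
#
#   import base64, io
#   import numpy as np
#   import requests
#
#   audio = np.zeros(16000, dtype=np.float32)  # 1 second of silence at 16 kHz
#   buf = io.BytesIO()
#   np.save(buf, audio)
#   payload = {
#       "audio_data": base64.b64encode(buf.getvalue()).decode(),
#       "sample_rate": 16000,
#   }
#   resp = requests.post("http://localhost:8000/api/v1/inference", json=payload)
#   reply = resp.json()
#   reply_audio = np.load(io.BytesIO(base64.b64decode(reply["audio_data"])))
#   print(reply["text"], reply_audio.shape)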