from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

import base64
import io
import logging
import os
import tempfile
import traceback

import numpy as np
import soundfile as sf
import torch
import whisper

from inference import OmniInference

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# CORS is left wide open here for development; restrict allow_origins in production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


class AudioRequest(BaseModel):
    # Base64-encoded bytes of a waveform serialized with np.save (.npy format).
    audio_data: str
    sample_rate: int


class AudioResponse(BaseModel):
    audio_data: str
    text: str = ""

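
# A minimal sketch of how a client could produce AudioRequest.audio_data; it
# mirrors the base64/np.load decoding in the inference handler below. The helper
# name is illustrative and nothing in the server calls it.
def _encode_audio_payload(waveform: np.ndarray) -> str:
    """Serialize a waveform to the base64-encoded .npy string the API expects."""
    buf = io.BytesIO()
    np.save(buf, waveform)  # .npy bytes; the server reads them back with np.load
    return base64.b64encode(buf.getvalue()).decode()
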

INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None,
}

# Global model handle; populated by initialize_model() at startup.
model = None


def initialize_model():
    """Initialize the OmniInference model."""
    global model, INITIALIZATION_STATUS
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing OmniInference model on device: {device}")

        ckpt_path = os.path.abspath("models")
        logger.info(f"Loading models from: {ckpt_path}")

        if not os.path.exists(ckpt_path):
            raise RuntimeError(f"Checkpoint path {ckpt_path} does not exist")

        model = OmniInference(ckpt_path, device=device)
        model.warm_up()

        INITIALIZATION_STATUS["model_loaded"] = True
        logger.info("OmniInference model initialized successfully")
        return True
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        logger.error(f"Failed to initialize model: {e}\n{traceback.format_exc()}")
        return False


@app.on_event("startup")
async def startup_event():
    """Initialize the model on FastAPI startup (see the lifespan note below)."""
    initialize_model()

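
# @app.on_event is deprecated in recent FastAPI releases in favor of a lifespan
# context manager. An equivalent sketch (not wired up here, since it would have
# to be passed to the FastAPI() constructor above):
#
#     from contextlib import asynccontextmanager
#
#     @asynccontextmanager
#     async def lifespan(app: FastAPI):
#         initialize_model()  # runs before the app starts serving
#         yield               # shutdown logic would go after the yield
#
#     app = FastAPI(lifespan=lifespan)
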

@app.get("/api/v1/health")
def health_check():
    """Health check endpoint."""
    status = {
        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"] else "initializing",
        "initialization_status": INITIALIZATION_STATUS,
    }

    if model is not None:
        status.update({
            "device": str(model.device),
            "model_loaded": True,
            "warm_up_complete": True,
        })

    return status

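
# Example probe (assuming the default uvicorn binding at the bottom of this file):
#
#     curl http://localhost:8000/api/v1/health
#
# The response reports "initializing" until the model has loaded; if loading
# failed, the error is surfaced under initialization_status.
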

@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with the OmniInference model."""
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )

    try:
        logger.info(f"Received inference request with sample rate: {request.sample_rate}")

        # The payload is a base64-encoded .npy blob; recover the waveform array.
        audio_bytes = base64.b64decode(request.audio_data)
        audio_array = np.load(io.BytesIO(audio_bytes))

        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
            if request.sample_rate != 16000:
                logger.warning("Sample rate conversion not implemented. Assuming 16kHz.")

            # Pad or trim to the fixed 30-second window Whisper-style models expect.
            audio_data = whisper.pad_or_trim(audio_array.flatten())
            sf.write(temp_wav.name, audio_data, 16000)

            final_text = ""
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav_out:
                # Accumulate streamed text deltas while the model writes audio
                # chunks to save_path.
                for audio_stream, text_stream in model.run_AT_batch_stream(
                    temp_wav.name,
                    stream_stride=4,
                    max_returned_tokens=2048,
                    save_path=temp_wav_out.name,
                    sample_rate=request.sample_rate
                ):
                    if text_stream:
                        final_text += text_stream
                final_audio, sample_rate = sf.read(temp_wav_out.name)
                # Sanity-check that the model wrote audio at the requested rate.
                assert sample_rate == request.sample_rate

        # Return the generated waveform the same way it arrived: .npy bytes, base64-encoded.
        buffer = io.BytesIO()
        np.save(buffer, final_audio)
        audio_b64 = base64.b64encode(buffer.getvalue()).decode()

        return AudioResponse(
            audio_data=audio_b64,
            text=final_text.strip()
        )

    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500,
            detail=str(e)
        )

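
# A minimal client round trip for /api/v1/inference (a sketch: it assumes the
# server is reachable at localhost:8000 and that the `requests` package is
# available; neither is required by the server itself):
#
#     import requests
#
#     payload = {"audio_data": _encode_audio_payload(waveform), "sample_rate": 16000}
#     reply = requests.post("http://localhost:8000/api/v1/inference", json=payload).json()
#     audio = np.load(io.BytesIO(base64.b64decode(reply["audio_data"])))
#     print(reply["text"])
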

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)