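"""FastAPI wrapper around the OmniInference speech model.

Endpoints:
    GET  /api/v1/health     -- report model initialization status
    POST /api/v1/inference  -- accept base64-encoded .npy audio and return
                               base64-encoded .npy audio plus generated text
"""
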
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import traceback
import numpy as np
import torch
import base64
import io
import os
import logging
import whisper  # Used for pad_or_trim (whisper's fixed 30-second audio window)
import soundfile as sf
from inference import OmniInference
import tempfile

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI()

# Add CORS middleware
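# NOTE: wide-open CORS (wildcard origins plus credentials) is convenient for
# development; pin allow_origins before exposing this service publicly.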
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

class AudioRequest(BaseModel):
    audio_data: str   # base64-encoded .npy buffer (output of np.save)
    sample_rate: int  # sample rate of the submitted audio, in Hz

class AudioResponse(BaseModel):
    audio_data: str   # base64-encoded .npy buffer of the generated audio
    text: str = ""    # text generated alongside the audio

# Model initialization status
INITIALIZATION_STATUS = {
    "model_loaded": False,
    "error": None
}

# Global model instance
model = None

def initialize_model():
    """Initialize the OmniInference model"""
    global model, INITIALIZATION_STATUS
    try:
        device = "cuda" if torch.cuda.is_available() else "cpu"
        logger.info(f"Initializing OmniInference model on device: {device}")
        
        ckpt_path = os.path.abspath('models')
        logger.info(f"Loading models from: {ckpt_path}")
        
        if not os.path.exists(ckpt_path):
            raise RuntimeError(f"Checkpoint path {ckpt_path} does not exist")
        
        # Publish the model only after warm-up succeeds, so the health check
        # never reports a half-initialized instance
        inference_model = OmniInference(ckpt_path, device=device)
        inference_model.warm_up()
        model = inference_model
        
        INITIALIZATION_STATUS["model_loaded"] = True
        logger.info("OmniInference model initialized successfully")
        return True
    except Exception as e:
        INITIALIZATION_STATUS["error"] = str(e)
        logger.error(f"Failed to initialize model: {e}\n{traceback.format_exc()}")
        return False

# NOTE: @app.on_event is deprecated in recent FastAPI releases in favor of
# lifespan handlers; it still works for this simple startup hook
@app.on_event("startup")
async def startup_event():
    """Initialize model on startup"""
    initialize_model()

@app.get("/api/v1/health")
def health_check():
    """Health check endpoint"""
    status = {
        # A failed initialization is reported as "error" rather than a
        # perpetual "initializing"
        "status": "healthy" if INITIALIZATION_STATUS["model_loaded"]
        else ("error" if INITIALIZATION_STATUS["error"] else "initializing"),
        "initialization_status": INITIALIZATION_STATUS
    }
    
    if model is not None:
        status.update({
            "device": str(model.device),
            "model_loaded": True,
            "warm_up_complete": True
        })
        
    return status

@app.post("/api/v1/inference")
async def inference(request: AudioRequest) -> AudioResponse:
    """Run inference with OmniInference model"""
    if not INITIALIZATION_STATUS["model_loaded"]:
        raise HTTPException(
            status_code=503,
            detail=f"Model not ready. Status: {INITIALIZATION_STATUS}"
        )
        
    try:
        logger.info(f"Received inference request with sample rate: {request.sample_rate}")
        
        # Decode audio data from base64 to numpy array
        audio_bytes = base64.b64decode(request.audio_data)
        audio_array = np.load(io.BytesIO(audio_bytes))
        
        # Save the numpy array as a temporary WAV file for OmniInference.
        # Note: reopening a NamedTemporaryFile by name while it is still open
        # works on POSIX but not on Windows.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav:
            # OmniInference expects 16 kHz input
            if request.sample_rate != 16000:
                # Resampling is not implemented: the audio is written below as
                # 16 kHz regardless, so other input rates will play back at
                # the wrong speed.
                logger.warning(
                    "Resampling not implemented; treating %d Hz input as 16 kHz.",
                    request.sample_rate,
                )
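                # A minimal resampling sketch (an assumption -- librosa is
                # not imported by this module):
                #
                #   import librosa
                #   audio_array = librosa.resample(
                #       audio_array.astype(np.float32).flatten(),
                #       orig_sr=request.sample_rate,
                #       target_sr=16000,
                #   )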
            
            # Flatten to a mono 1-D array, pad or trim to whisper's fixed
            # 30-second window, and write the temp WAV with soundfile
            audio_data = whisper.pad_or_trim(audio_array.flatten())
            sf.write(temp_wav.name, audio_data, 16000)
            
            # Run inference with streaming
            final_text = ""
            
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=True) as temp_wav_out:
                # Drain the streaming generator; the audio chunks are ignored
                # here because the full waveform is read back from save_path
                for _audio_stream, text_stream in model.run_AT_batch_stream(
                    temp_wav.name,
                    stream_stride=4,
                    max_returned_tokens=2048,
                    save_path=temp_wav_out.name,
                    sample_rate=request.sample_rate
                ):
                    if text_stream:  # Accumulate non-empty text chunks
                        final_text += text_stream
                final_audio, out_sample_rate = sf.read(temp_wav_out.name)
                # Use an explicit check rather than assert, which is stripped
                # under `python -O`
                if out_sample_rate != request.sample_rate:
                    raise RuntimeError(
                        f"Model produced {out_sample_rate} Hz audio, "
                        f"expected {request.sample_rate} Hz"
                    )

            # Encode output array to base64
            buffer = io.BytesIO()
            np.save(buffer, final_audio)
            audio_b64 = base64.b64encode(buffer.getvalue()).decode()

            return AudioResponse(
                audio_data=audio_b64,
                text=final_text.strip()
            )

    except Exception as e:
        logger.error(f"Inference failed: {str(e)}", exc_info=True)
        raise HTTPException(
            status_code=500, 
            detail=str(e)
        )

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
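
# ---------------------------------------------------------------------------
# Example client (a minimal sketch, not part of the server): assumes the
# service is reachable at http://localhost:8000 and that `requests` is
# installed. The payload mirrors AudioRequest above: `audio_data` is a
# base64-encoded .npy buffer produced by np.save, and the response body
# decodes the same way.
#
#   import base64, io
#   import numpy as np
#   import requests
#
#   audio = np.zeros(16000, dtype=np.float32)  # 1 s of silence at 16 kHz
#   buf = io.BytesIO()
#   np.save(buf, audio)
#   payload = {
#       "audio_data": base64.b64encode(buf.getvalue()).decode(),
#       "sample_rate": 16000,
#   }
#   resp = requests.post("http://localhost:8000/api/v1/inference", json=payload)
#   resp.raise_for_status()
#   body = resp.json()
#   print(body["text"])
#   out_audio = np.load(io.BytesIO(base64.b64decode(body["audio_data"])))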