from fastapi import FastAPI, HTTPException, Request
from onnxruntime import InferenceSession
from transformers import AutoTokenizer
import numpy as np
import uvicorn

app = FastAPI()

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(
    "Xenova/multi-qa-mpnet-base-dot-v1",
    use_fast=True,
)

# Load ONNX model
try:
    session = InferenceSession("model.onnx")
    print("Model loaded successfully")
except Exception as e:
    print(f"Failed to load model: {str(e)}")
    raise


@app.get("/")
def health_check():
    return {"status": "OK", "model": "ONNX"}


@app.post("/api/predict")
async def predict(request: Request):
    try:
        # Get JSON input
        data = await request.json()
        text = data.get("text", "")

        if not text:
            raise HTTPException(status_code=400, detail="No text provided")

        # Tokenize input
        inputs = tokenizer(
            text,
            return_tensors="np",
            padding="max_length",
            truncation=True,
            max_length=32,
        )

        # Prepare ONNX inputs with the int64 dtype the model expects
        onnx_inputs = {
            "input_ids": inputs["input_ids"].astype(np.int64),
            "attention_mask": inputs["attention_mask"].astype(np.int64),
        }

        # Run inference
        outputs = session.run(None, onnx_inputs)

        # Convert outputs to plain Python floats so FastAPI can serialize them
        embedding = outputs[0][0].astype(float).tolist()  # first output, first batch item

        return {
            "embedding": embedding,
            "tokens": tokenizer.convert_ids_to_tokens(inputs["input_ids"][0].tolist()),
        }
    except HTTPException:
        # Let deliberate HTTP errors (e.g. the 400 above) propagate unchanged
        # instead of being swallowed by the generic 500 handler below
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))


if __name__ == "__main__":
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
    )
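
# --- Usage sketch (illustrative only, not part of the app) ---
# Assuming the server above is running locally on port 7860, the /api/predict
# route can be exercised with a minimal client along these lines. The
# `requests` dependency and the example sentence are hypothetical choices for
# illustration, not requirements of the app:
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:7860/api/predict",
#       json={"text": "How do I export a transformer to ONNX?"},
#   )
#   resp.raise_for_status()
#   payload = resp.json()
#   print(len(payload["embedding"]))   # embedding dimensionality
#   print(payload["tokens"][:8])       # first few wordpiece tokens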