"""ONNX embedding inference service for Hugging Face Spaces.

Exposes a health-check endpoint and a single prediction endpoint that runs
an ONNX model on token ids + attention mask supplied as JSON.
"""

import logging
import os  # NOTE(review): unused in this file as shown; kept in case other tooling relies on it

import numpy as np
import uvicorn
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)

# Initialize FastAPI with docs disabled for Spaces.
app = FastAPI(docs_url=None, redoc_url=None)

# CORS configuration: wide open, the Space is meant to be callable from any origin.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the ONNX model once at import time; fail fast if it is missing or corrupt
# so the container never reports healthy with no model.
try:
    session = InferenceSession("model.onnx")
    logger.info("Model loaded successfully")
except Exception:
    logger.exception("Model loading failed")
    raise


@app.get("/")
async def health_check():
    """Liveness/readiness probe for the Spaces runtime."""
    return {"status": "ready", "model": "onnx"}


@app.post("/api/predict")
async def predict(request: Request):
    """Run one inference pass over the model.

    Expects a JSON body: {"input_ids": [...], "attention_mask": [...]}.
    Returns {"embedding": [...]} — the model's first output, as nested lists.

    Raises:
        HTTPException 400: malformed JSON, missing fields, or non-numeric arrays.
        HTTPException 500: the model itself failed during inference.
    """
    # Stage 1: parse the body. A non-JSON payload is a client error.
    try:
        data = await request.json()
    except Exception as e:
        raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}")

    # Stage 2: validate required keys explicitly so clients get a clear
    # message instead of an opaque KeyError repr like "'input_ids'".
    missing = [k for k in ("input_ids", "attention_mask") if k not in data]
    if missing:
        raise HTTPException(
            status_code=400,
            detail=f"Missing required field(s): {', '.join(missing)}",
        )

    # Stage 3: convert to int64 arrays with batch dimension of 1; the inputs
    # are expected to arrive as flat token-id lists.
    try:
        input_ids = np.array(data["input_ids"], dtype=np.int64).reshape(1, -1)
        attention_mask = np.array(data["attention_mask"], dtype=np.int64).reshape(1, -1)
    except (TypeError, ValueError) as e:
        raise HTTPException(status_code=400, detail=f"Malformed input arrays: {e}")

    # Stage 4: inference. A failure here is a server-side fault, so report
    # 500 rather than blaming the client with a 400.
    try:
        outputs = session.run(
            None,
            {"input_ids": input_ids, "attention_mask": attention_mask},
        )
    except Exception as e:
        logger.exception("Inference failed")
        raise HTTPException(status_code=500, detail=f"Inference failed: {e}")

    return {"embedding": outputs[0].tolist()}


# Required for Hugging Face Spaces: serve on the conventional port 7860.
if __name__ == "__main__":
    uvicorn.run(
        "app:app",
        host="0.0.0.0",
        port=7860,
        reload=False,
    )