"""FastAPI application that serves an ONNX model over HTTP."""

import logging
import os

import numpy as np
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
from fastapi.middleware.cors import CORSMiddleware
from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)

app = FastAPI(title="ONNX Model API")

# CORS configuration: every origin, method, and header is accepted.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load ONNX model once, at import time. The path is resolved against the
# process working directory, so the server has to be started from the
# directory that contains model.onnx.
model_path = os.path.join(os.getcwd(), "model.onnx")
session = InferenceSession(model_path)


def _as_int64_row(values):
    """Return *values* as an int64 array reshaped to a single row ``(1, -1)``."""
    return np.array(values, dtype=np.int64).reshape(1, -1)


@app.get("/")
def health_check():
    """Return a static payload signalling that the service is up."""
    return {"status": "healthy", "message": "ONNX model is ready"}


@app.post("/predict")
async def predict(inputs: dict):
    """Run the ONNX session on one request and return its first output.

    Expects {'input_ids': [], 'attention_mask': []}

    Both values are converted to int64 arrays with a leading dimension of 1
    and fed to the session under the same names.

    Returns:
        ``{"embedding": ...}`` holding the first session output converted to
        nested lists, or ``{"error": "<message>"}`` when anything fails.
    """
    try:
        feeds = {
            "input_ids": _as_int64_row(inputs["input_ids"]),
            "attention_mask": _as_int64_row(inputs["attention_mask"]),
        }
        # session.run is a blocking, CPU-bound call; running it directly in
        # this coroutine would stall the event loop (and every other request)
        # for the duration of the inference, so hand it to the threadpool.
        outputs = await run_in_threadpool(session.run, None, feeds)
        return {"embedding": outputs[0].tolist()}
    except Exception as e:
        # Keep the traceback in the server log instead of discarding it.
        logger.exception("Prediction failed")
        # NOTE(review): failures are still reported in the response body with
        # the default 200 status, preserved for existing clients; consider
        # returning a 4xx/5xx status code instead.
        return {"error": str(e)}


# Required for Hugging Face Spaces
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)