Update app.py
Browse files
app.py
CHANGED
@@ -5,8 +5,7 @@ import numpy as np
|
|
5 |
import os
|
6 |
import uvicorn
|
7 |
|
8 |
-
|
9 |
-
app = FastAPI(docs_url=None, redoc_url=None)
|
10 |
|
11 |
# CORS configuration
|
12 |
app.add_middleware(
|
@@ -17,46 +16,42 @@ app.add_middleware(
|
|
17 |
)
|
18 |
|
19 |
# Load ONNX model
|
20 |
-
|
21 |
-
session = InferenceSession("model.onnx")
|
22 |
-
print("Model loaded successfully")
|
23 |
-
except Exception as e:
|
24 |
-
print(f"Model loading failed: {str(e)}")
|
25 |
-
raise
|
26 |
|
|
|
27 |
@app.get("/")
|
28 |
-
|
29 |
-
return {"status": "
|
30 |
|
31 |
-
|
|
|
32 |
async def predict(request: Request):
|
33 |
try:
|
34 |
-
# Get JSON input
|
35 |
data = await request.json()
|
36 |
-
|
37 |
-
# Convert to numpy arrays with correct shape
|
38 |
input_ids = np.array(data["input_ids"], dtype=np.int64).reshape(1, -1)
|
39 |
attention_mask = np.array(data["attention_mask"], dtype=np.int64).reshape(1, -1)
|
40 |
|
41 |
-
|
42 |
-
|
43 |
-
|
44 |
-
|
45 |
-
"input_ids": input_ids,
|
46 |
-
"attention_mask": attention_mask
|
47 |
-
}
|
48 |
-
)
|
49 |
|
50 |
return {"embedding": outputs[0].tolist()}
|
51 |
|
52 |
except Exception as e:
|
53 |
raise HTTPException(status_code=400, detail=str(e))
|
54 |
|
55 |
-
#
|
|
|
|
|
|
|
|
|
56 |
if __name__ == "__main__":
|
57 |
uvicorn.run(
|
58 |
-
|
59 |
host="0.0.0.0",
|
60 |
port=7860,
|
61 |
-
|
|
|
|
|
62 |
)
|
|
|
5 |
import os
|
6 |
import uvicorn
|
7 |
|
8 |
+
# Application object that the middleware registration and the route
# decorators in this file attach to.
app = FastAPI(title="ONNX Model API")
|
|
|
9 |
|
10 |
# CORS configuration
|
11 |
app.add_middleware(
|
|
|
16 |
)
|
17 |
|
18 |
# Load ONNX model
|
19 |
+
# Created once when the module is imported and reused by the prediction
# handler. There is no error handling around the load, so any exception
# raised here propagates and stops the module from importing.
# NOTE(review): "model.onnx" is a relative path - presumably resolved against
# the process working directory; confirm this matches the deployment layout.
session = InferenceSession("model.onnx")
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
+
# Essential for Spaces health checks
|
22 |
@app.get("/")
def read_root():
    """Return a fixed status payload for ``GET /``."""
    status_payload = {"status": "ONNX Model API is running"}
    return status_payload
|
25 |
|
26 |
+
# Main prediction endpoint
|
27 |
+
@app.post("/predict")
async def predict(request: Request):
    """Run the ONNX model on the token data supplied in the request body.

    The body must be a JSON object containing ``input_ids`` and
    ``attention_mask``. Each value is converted to an int64 array and
    reshaped to a single row of shape ``(1, n)`` before being fed to the
    model.

    Returns:
        A dict ``{"embedding": ...}`` holding the model's first output
        converted to nested Python lists.

    Raises:
        HTTPException: status 400 with an explanatory ``detail`` when the
            body is not a JSON object, a required field is missing, the two
            fields differ in length, or conversion/inference fails.
    """
    # Imported locally so this handler does not rely on the file's import
    # block being updated alongside it.
    from fastapi.concurrency import run_in_threadpool

    required_fields = ("input_ids", "attention_mask")

    try:
        data = await request.json()

        # Reject anything that is not a JSON object up front; otherwise the
        # field lookups below fail with an opaque TypeError/KeyError message.
        if not isinstance(data, dict):
            raise HTTPException(
                status_code=400,
                detail="Request body must be a JSON object",
            )

        missing = [field for field in required_fields if field not in data]
        if missing:
            raise HTTPException(
                status_code=400,
                detail=f"Missing required field(s): {', '.join(missing)}",
            )

        input_ids = np.array(data["input_ids"], dtype=np.int64).reshape(1, -1)
        attention_mask = np.array(data["attention_mask"], dtype=np.int64).reshape(1, -1)

        # The mask is applied position-by-position to the ids, so the two
        # rows have to be the same length.
        if input_ids.shape != attention_mask.shape:
            raise HTTPException(
                status_code=400,
                detail=(
                    "input_ids and attention_mask must have the same length "
                    f"(got {input_ids.shape[1]} and {attention_mask.shape[1]})"
                ),
            )

        # session.run is a blocking call; running it in a worker thread keeps
        # the event loop free to serve other requests while the model runs.
        outputs = await run_in_threadpool(
            session.run,
            None,
            {
                "input_ids": input_ids,
                "attention_mask": attention_mask,
            },
        )

        return {"embedding": outputs[0].tolist()}

    except HTTPException:
        # Already carries the intended status and detail; do not re-wrap it.
        raise
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
|
43 |
|
44 |
+
# Special endpoint for Spaces compatibility
|
45 |
+
@app.post("/api/predict")
async def spaces_predict(request: Request):
    """Alias route: hand the request to ``predict`` and return its result."""
    response = await predict(request)
    return response
|
48 |
+
|
49 |
if __name__ == "__main__":
    # Server settings collected in one mapping and expanded into the call.
    server_options = {
        "host": "0.0.0.0",
        "port": 7860,
        # Required for Spaces:
        "proxy_headers": True,
        "forwarded_allow_ips": "*",
    }
    uvicorn.run(app, **server_options)
|