chryzxc committed on
Commit
4afa954
·
verified ·
1 Parent(s): 658ebc3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -25
app.py CHANGED
@@ -5,8 +5,7 @@ import numpy as np
5
  import os
6
  import uvicorn
7
 
8
- # Initialize FastAPI with docs disabled for Spaces
9
- app = FastAPI(docs_url=None, redoc_url=None)
10
 
11
  # CORS configuration
12
  app.add_middleware(
@@ -17,46 +16,42 @@ app.add_middleware(
17
  )
18
 
19
  # Load ONNX model
20
- try:
21
- session = InferenceSession("model.onnx")
22
- print("Model loaded successfully")
23
- except Exception as e:
24
- print(f"Model loading failed: {str(e)}")
25
- raise
26
 
 
27
  @app.get("/")
28
- async def health_check():
29
- return {"status": "ready", "model": "onnx"}
30
 
31
- @app.post("/api/predict")
 
32
  async def predict(request: Request):
33
  try:
34
- # Get JSON input
35
  data = await request.json()
36
-
37
- # Convert to numpy arrays with correct shape
38
  input_ids = np.array(data["input_ids"], dtype=np.int64).reshape(1, -1)
39
  attention_mask = np.array(data["attention_mask"], dtype=np.int64).reshape(1, -1)
40
 
41
- # Run inference
42
- outputs = session.run(
43
- None,
44
- {
45
- "input_ids": input_ids,
46
- "attention_mask": attention_mask
47
- }
48
- )
49
 
50
  return {"embedding": outputs[0].tolist()}
51
 
52
  except Exception as e:
53
  raise HTTPException(status_code=400, detail=str(e))
54
 
55
- # Required for Hugging Face Spaces
 
 
 
 
56
  if __name__ == "__main__":
57
  uvicorn.run(
58
- "app:app",
59
  host="0.0.0.0",
60
  port=7860,
61
- reload=False
 
 
62
  )
 
5
  import os
6
  import uvicorn
7
 
8
+ app = FastAPI(title="ONNX Model API")
 
9
 
10
  # CORS configuration
11
  app.add_middleware(
 
16
  )
17
 
18
  # Load ONNX model
19
+ session = InferenceSession("model.onnx")
 
 
 
 
 
20
 
21
+ # Essential for Spaces health checks
22
  @app.get("/")
23
+ def read_root():
24
+ return {"status": "ONNX Model API is running"}
25
 
26
+ # Main prediction endpoint
27
+ @app.post("/predict")
28
  async def predict(request: Request):
29
  try:
 
30
  data = await request.json()
 
 
31
  input_ids = np.array(data["input_ids"], dtype=np.int64).reshape(1, -1)
32
  attention_mask = np.array(data["attention_mask"], dtype=np.int64).reshape(1, -1)
33
 
34
+ outputs = session.run(None, {
35
+ "input_ids": input_ids,
36
+ "attention_mask": attention_mask
37
+ })
 
 
 
 
38
 
39
  return {"embedding": outputs[0].tolist()}
40
 
41
  except Exception as e:
42
  raise HTTPException(status_code=400, detail=str(e))
43
 
44
+ # Special endpoint for Spaces compatibility
45
+ @app.post("/api/predict")
46
+ async def spaces_predict(request: Request):
47
+ return await predict(request)
48
+
49
  if __name__ == "__main__":
50
  uvicorn.run(
51
+ app,
52
  host="0.0.0.0",
53
  port=7860,
54
+ # Required for Spaces:
55
+ proxy_headers=True,
56
+ forwarded_allow_ips="*"
57
  )