Spaces:
Sleeping
Sleeping
Sushan
commited on
Commit
·
8b2caaf
1
Parent(s):
7aa2125
changed the pytorch model
Browse files- app.py +30 -11
- model.pth +3 -0
- requirements.txt +1 -2
app.py
CHANGED
@@ -1,29 +1,48 @@
|
|
1 |
from fastapi import FastAPI
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
|
|
3 |
import pandas as pd
|
4 |
-
import joblib
|
5 |
|
6 |
-
#
|
7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
8 |
|
9 |
app = FastAPI()
|
10 |
|
11 |
-
#
|
12 |
app.add_middleware(
|
13 |
CORSMiddleware,
|
14 |
-
allow_origins=["*"], # Allow all origins
|
15 |
allow_credentials=True,
|
16 |
-
allow_methods=["*"], # Allow all methods
|
17 |
allow_headers=["*"], # Allow all headers
|
18 |
)
|
19 |
|
20 |
@app.post("/predict")
|
21 |
async def predict(features: dict):
|
22 |
-
# Convert the input
|
23 |
input_data = pd.DataFrame([features])
|
|
|
24 |
|
25 |
-
# Make prediction
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
|
|
|
|
1 |
from fastapi import FastAPI
|
2 |
from fastapi.middleware.cors import CORSMiddleware
|
3 |
+
import torch
|
4 |
import pandas as pd
|
|
|
5 |
|
6 |
+
# Define the model structure (ensure this matches your model class)
|
7 |
+
class AsteroidModel(torch.nn.Module):
|
8 |
+
def __init__(self):
|
9 |
+
super(AsteroidModel, self).__init__()
|
10 |
+
# Define the layers as per your original model architecture
|
11 |
+
self.fc1 = torch.nn.Linear(5, 16)
|
12 |
+
self.fc2 = torch.nn.Linear(16, 8)
|
13 |
+
self.fc3 = torch.nn.Linear(8, 1)
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
x = torch.relu(self.fc1(x))
|
17 |
+
x = torch.relu(self.fc2(x))
|
18 |
+
x = torch.sigmoid(self.fc3(x))
|
19 |
+
return x
|
20 |
+
|
21 |
+
# Initialize the model and load the saved weights
|
22 |
+
model = AsteroidModel()
|
23 |
+
model.load_state_dict(torch.load('model.pth'))
|
24 |
+
model.eval() # Set model to evaluation mode
|
25 |
|
26 |
app = FastAPI()
|
27 |
|
28 |
+
# CORS middleware to handle cross-origin requests
|
29 |
app.add_middleware(
|
30 |
CORSMiddleware,
|
31 |
+
allow_origins=["*"], # Allow all origins, adjust if needed
|
32 |
allow_credentials=True,
|
33 |
+
allow_methods=["*"], # Allow all methods
|
34 |
allow_headers=["*"], # Allow all headers
|
35 |
)
|
36 |
|
37 |
@app.post("/predict")
|
38 |
async def predict(features: dict):
|
39 |
+
# Convert the input to a tensor
|
40 |
input_data = pd.DataFrame([features])
|
41 |
+
input_tensor = torch.tensor(input_data.values, dtype=torch.float32)
|
42 |
|
43 |
+
# Make prediction
|
44 |
+
with torch.no_grad():
|
45 |
+
output = model(input_tensor).squeeze()
|
46 |
+
prediction = (output > 0.5).float().item() # Convert to binary prediction
|
47 |
|
48 |
+
return {"is_potentially_hazardous_asteroid": int(prediction)}
|
model.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:750fa82060eb6b41c0be0db37365efb3c7ed2af1cdb5095491ec6a9a7ceae8ca
|
3 |
+
size 3252
|
requirements.txt
CHANGED
@@ -1,5 +1,4 @@
|
|
1 |
fastapi
|
2 |
uvicorn
|
3 |
pandas
|
4 |
-
|
5 |
-
joblib
|
|
|
1 |
fastapi
|
2 |
uvicorn
|
3 |
pandas
|
4 |
+
torch
|
|