Sushan commited on
Commit
8b2caaf
·
1 Parent(s): 7aa2125

changed the pytorch model

Browse files
Files changed (3) hide show
  1. app.py +30 -11
  2. model.pth +3 -0
  3. requirements.txt +1 -2
app.py CHANGED
@@ -1,29 +1,48 @@
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
 
3
  import pandas as pd
4
- import joblib
5
 
6
- # Load the trained model
7
- model = joblib.load("model.pkl") # Ensure your model is saved as 'model.pkl'
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
  app = FastAPI()
10
 
11
- # Add CORS middleware to allow requests from any origin
12
  app.add_middleware(
13
  CORSMiddleware,
14
- allow_origins=["*"], # Allow all origins (adjust if needed)
15
  allow_credentials=True,
16
- allow_methods=["*"], # Allow all methods (GET, POST, etc.)
17
  allow_headers=["*"], # Allow all headers
18
  )
19
 
20
  @app.post("/predict")
21
  async def predict(features: dict):
22
- # Convert the input into a DataFrame
23
  input_data = pd.DataFrame([features])
 
24
 
25
- # Make prediction using the trained model
26
- prediction = model.predict(input_data)
27
-
28
- return {"is_potentially_hazardous_asteroid": int(prediction[0])}
29
 
 
 
1
  from fastapi import FastAPI
2
  from fastapi.middleware.cors import CORSMiddleware
3
+ import torch
4
  import pandas as pd
 
5
 
6
+ # Define the model structure (ensure this matches your model class)
7
+ class AsteroidModel(torch.nn.Module):
8
+ def __init__(self):
9
+ super(AsteroidModel, self).__init__()
10
+ # Define the layers as per your original model architecture
11
+ self.fc1 = torch.nn.Linear(5, 16)
12
+ self.fc2 = torch.nn.Linear(16, 8)
13
+ self.fc3 = torch.nn.Linear(8, 1)
14
+
15
+ def forward(self, x):
16
+ x = torch.relu(self.fc1(x))
17
+ x = torch.relu(self.fc2(x))
18
+ x = torch.sigmoid(self.fc3(x))
19
+ return x
20
+
21
+ # Initialize the model and load the saved weights
22
+ model = AsteroidModel()
23
+ model.load_state_dict(torch.load('model.pth'))
24
+ model.eval() # Set model to evaluation mode
25
 
26
  app = FastAPI()
27
 
28
+ # CORS middleware to handle cross-origin requests
29
  app.add_middleware(
30
  CORSMiddleware,
31
+ allow_origins=["*"], # Allow all origins, adjust if needed
32
  allow_credentials=True,
33
+ allow_methods=["*"], # Allow all methods
34
  allow_headers=["*"], # Allow all headers
35
  )
36
 
37
  @app.post("/predict")
38
  async def predict(features: dict):
39
+ # Convert the input to a tensor
40
  input_data = pd.DataFrame([features])
41
+ input_tensor = torch.tensor(input_data.values, dtype=torch.float32)
42
 
43
+ # Make prediction
44
+ with torch.no_grad():
45
+ output = model(input_tensor).squeeze()
46
+ prediction = (output > 0.5).float().item() # Convert to binary prediction
47
 
48
+ return {"is_potentially_hazardous_asteroid": int(prediction)}
model.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:750fa82060eb6b41c0be0db37365efb3c7ed2af1cdb5095491ec6a9a7ceae8ca
3
+ size 3252
requirements.txt CHANGED
@@ -1,5 +1,4 @@
1
  fastapi
2
  uvicorn
3
  pandas
4
- scikit-learn
5
- joblib
 
1
  fastapi
2
  uvicorn
3
  pandas
4
+ torch