File size: 1,768 Bytes
7aa2125
 
8b2caaf
7aa2125
 
8b2caaf
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f5baa5
8b2caaf
7aa2125
 
 
8b2caaf
7aa2125
 
8b2caaf
7aa2125
8b2caaf
7aa2125
 
 
 
 
8b2caaf
7aa2125
8b2caaf
7aa2125
8b2caaf
 
 
 
7aa2125
8b2caaf
6f5baa5
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
import torch
import pandas as pd

# Define the model structure (ensure this matches your model class)
class AsteroidModel(torch.nn.Module):
    """Feed-forward binary classifier: 5 inputs -> 16 -> 8 -> 1 output.

    The attribute names ``fc1``, ``fc2`` and ``fc3`` are what
    ``load_state_dict`` matches checkpoint entries against, so they (and
    the layer sizes) must stay in sync with the saved ``model.pth``.
    """

    def __init__(self):
        super().__init__()
        # Layer widths must match the architecture the weights were trained with.
        self.fc1 = torch.nn.Linear(5, 16)
        self.fc2 = torch.nn.Linear(16, 8)
        self.fc3 = torch.nn.Linear(8, 1)

    def forward(self, x):
        """Map inputs whose last dimension is 5 to sigmoid scores in [0, 1]."""
        hidden = x
        # Both hidden layers use a ReLU activation.
        for layer in (self.fc1, self.fc2):
            hidden = torch.relu(layer(hidden))
        # Single output unit squashed to a probability-like score.
        return torch.sigmoid(self.fc3(hidden))

# Build the network and load the trained weights once, at import time.
# map_location="cpu" lets a checkpoint that was saved from a GPU load on a
# CPU-only host; without it torch.load tries to restore tensors to the
# device they were saved from and fails when CUDA is unavailable.
# weights_only=True restricts unpickling to tensors and plain containers,
# so a tampered checkpoint cannot run arbitrary code.
model = AsteroidModel()
model.load_state_dict(torch.load('model.pth', weights_only=True, map_location="cpu"))
# Inference mode; it matters for dropout/batch-norm layers, which the
# current AsteroidModel does not use, but keeps this correct if they are added.
model.eval()

app = FastAPI()

# Cross-origin policy: accept every origin, method and header, and allow
# credentials. Tighten these lists before exposing the API beyond a demo.
# NOTE(review): FastAPI's CORS docs say "*" cannot be combined with
# allow_credentials=True; Starlette appears to handle it by reflecting the
# caller's Origin, i.e. granting credentialed access to any site. /predict
# reads no cookies or auth headers, so allow_credentials=False is likely
# safe — confirm with the frontend before changing it.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)

def _as_finite_float(value):
    """Return ``value`` as a float, or ``None`` if it is not a finite real number.

    Accepts ``int`` and ``float`` (and therefore ``bool``, an ``int``
    subclass); rejects strings, ``None``, lists/objects, NaN/infinity and
    integers too large to represent as a float.
    """
    import math

    if not isinstance(value, (int, float)):
        return None
    try:
        number = float(value)
    except OverflowError:  # e.g. a JSON integer with hundreds of digits
        return None
    return number if math.isfinite(number) else None


@app.post("/predict")
async def predict(features: dict):
    """Classify one asteroid as potentially hazardous (1) or not (0).

    ``features`` is the JSON request body: a flat object of feature values.
    Only the values are used, in the order the client sent them, so they
    must follow the column order the model was trained on.
    NOTE(review): the expected feature names are not visible in this file —
    consider validating keys/order explicitly once they are known.

    Returns:
        ``{"is_potentially_hazardous_asteroid": 0 or 1}``.

    Raises:
        HTTPException(422): if the body does not hold exactly as many values
            as the model's first layer accepts, or any value is not a finite
            number. Previously such input crashed inside pandas/torch and
            surfaced as an HTTP 500.
    """
    from fastapi import HTTPException

    values = list(features.values())
    expected = model.fc1.in_features  # input width of the first Linear layer
    if len(values) != expected:
        raise HTTPException(
            status_code=422,
            detail=f"Expected {expected} feature values, got {len(values)}.",
        )

    row = [_as_finite_float(value) for value in values]
    if None in row:
        raise HTTPException(
            status_code=422,
            detail="All feature values must be finite numbers.",
        )

    # Single-row batch of shape (1, n_features), as the DataFrame path produced.
    input_tensor = torch.tensor([row], dtype=torch.float32)

    with torch.no_grad():
        probability = model(input_tensor).item()

    # Strictly-greater-than 0.5 threshold on the sigmoid output, as before.
    return {"is_potentially_hazardous_asteroid": int(probability > 0.5)}

import os

if __name__ == "__main__":
    import uvicorn

    # The hosting environment may supply PORT; fall back to 7860.
    port = int(os.environ.get("PORT", 7860))
    # Serve the already-built app object rather than the "app:app" import
    # string: the string form makes uvicorn import this file again as module
    # "app" (loading the model a second time) and fails if the file is not
    # named app.py. If reload=True or workers>1 are ever needed, uvicorn
    # requires the import-string form instead.
    uvicorn.run(app, host="0.0.0.0", port=port)