# orreryspaceapp / app.py
# Sushan
# hope it works
# 6f5baa5
import os

import pandas as pd
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
# Network definition. It must mirror the architecture the saved weights were
# trained with, or load_state_dict() below will reject them.
class AsteroidModel(torch.nn.Module):
    """Feed-forward binary classifier: 5 inputs -> 16 -> 8 -> 1.

    The attribute names ``fc1``, ``fc2`` and ``fc3`` must not change. They
    become the ``state_dict`` keys (``fc1.weight``, ``fc1.bias``, ...) that the
    saved weights are loaded under.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 16)
        self.fc2 = torch.nn.Linear(16, 8)
        self.fc3 = torch.nn.Linear(8, 1)

    def forward(self, x):
        """Map each row of ``x`` to a sigmoid score in [0, 1].

        The two hidden layers use ReLU. The output layer uses a sigmoid so
        that callers can threshold the score.
        """
        for hidden_layer in (self.fc1, self.fc2):
            x = torch.relu(hidden_layer(x))
        return torch.sigmoid(self.fc3(x))
# Initialize the model and load the saved weights
model = AsteroidModel()
model.load_state_dict(torch.load('model.pth', weights_only=True))
model.eval() # Set model to evaluation mode
app = FastAPI()
# CORS middleware to handle cross-origin requests
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allow all origins, adjust if needed
allow_credentials=True,
allow_methods=["*"], # Allow all methods
allow_headers=["*"], # Allow all headers
)
@app.post("/predict")
async def predict(features: dict):
# Convert the input to a tensor
input_data = pd.DataFrame([features])
input_tensor = torch.tensor(input_data.values, dtype=torch.float32)
# Make prediction
with torch.no_grad():
output = model(input_tensor).squeeze()
prediction = (output > 0.5).float().item() # Convert to binary prediction
return {"is_potentially_hazardous_asteroid": int(prediction)}
import os
if __name__ == "__main__":
import uvicorn
port = int(os.environ.get("PORT", 7860)) # Set the default port to 7860
uvicorn.run("app:app", host="0.0.0.0", port=port)