"""FastAPI service that serves binary hazard predictions from a small PyTorch MLP.

POST /predict with a JSON object of 5 numeric features; the response is
``{"is_potentially_hazardous_asteroid": 0 | 1}``.
"""

import os

import pandas as pd  # noqa: F401  retained import; not used directly in this module
import torch
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware


class AsteroidModel(torch.nn.Module):
    """Three-layer fully connected binary classifier.

    The layer names and sizes must match the saved state dict exactly,
    otherwise ``load_state_dict`` raises on mismatched keys or shapes.
    """

    def __init__(self):
        super().__init__()
        self.fc1 = torch.nn.Linear(5, 16)
        self.fc2 = torch.nn.Linear(16, 8)
        self.fc3 = torch.nn.Linear(8, 1)

    def forward(self, x):
        """Map a float tensor with last dimension 5 to sigmoid outputs in (0, 1)."""
        x = torch.relu(self.fc1(x))
        x = torch.relu(self.fc2(x))
        return torch.sigmoid(self.fc3(x))


# Path to the trained weights; overridable via env var, defaults to the original file name.
MODEL_PATH = os.environ.get("MODEL_PATH", "model.pth")
# Decision threshold applied to the sigmoid output (strictly greater -> hazardous).
THRESHOLD = 0.5

model = AsteroidModel()
# map_location="cpu" lets weights that were saved on a GPU load on CPU-only hosts;
# weights_only=True restricts unpickling to tensors/primitive containers.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu", weights_only=True))
model.eval()  # inference mode for any train/eval-dependent layers

# Number of input features the first layer expects.
N_FEATURES = model.fc1.in_features

app = FastAPI()

# NOTE(review): this API uses no cookies or auth, so allow_credentials=True is
# likely unnecessary; combined with wildcard origins it is very permissive.
# Tighten allow_origins / allow_credentials before exposing this publicly.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.post("/predict")
def predict(features: dict):
    """Classify one asteroid from a JSON object of numeric features.

    Declared as a plain ``def`` (not ``async def``) so FastAPI runs the
    CPU-bound model call in its threadpool instead of blocking the event loop.

    Args:
        features: JSON object with exactly ``N_FEATURES`` numeric values.
            NOTE(review): values are fed to the model in the order the keys
            appear in the request body; the model presumably expects a fixed
            column order from training — confirm and enforce it if clients
            may send keys in a different order.

    Returns:
        ``{"is_potentially_hazardous_asteroid": 1}`` when the model output
        exceeds ``THRESHOLD``, else ``0``.

    Raises:
        HTTPException: 422 if the feature count is wrong or a value is not numeric.
    """
    if len(features) != N_FEATURES:
        raise HTTPException(
            status_code=422,
            detail=f"Expected {N_FEATURES} features, got {len(features)}",
        )
    try:
        row = [float(value) for value in features.values()]
    except (TypeError, ValueError) as err:
        raise HTTPException(
            status_code=422, detail="All feature values must be numeric"
        ) from err

    input_tensor = torch.tensor([row], dtype=torch.float32)  # shape (1, N_FEATURES)
    with torch.no_grad():
        probability = model(input_tensor).item()
    return {"is_potentially_hazardous_asteroid": int(probability > THRESHOLD)}


if __name__ == "__main__":
    import uvicorn

    port = int(os.environ.get("PORT", 7860))  # default port 7860
    # Pass the app object instead of the "app:app" import string: the string
    # form re-imports this file as module "app" (loading the model twice) and
    # only works when the file is literally named app.py.
    uvicorn.run(app, host="0.0.0.0", port=port)