Spaces:
Sleeping
Sleeping
| from fastapi import FastAPI | |
| from fastapi.middleware.cors import CORSMiddleware | |
| import torch | |
| import pandas as pd | |
class AsteroidModel(torch.nn.Module):
    """Binary classifier: a 5 -> 16 -> 8 -> 1 fully connected network.

    Hidden layers use ReLU; the output goes through a sigmoid, so each row of
    the result is a value in (0, 1).

    The layer attribute names (``fc1``, ``fc2``, ``fc3``) must match the keys
    of the saved state_dict this class is loaded from.
    """

    def __init__(self):
        super().__init__()
        # Created in this order so parameter initialisation consumes the RNG
        # exactly as before.
        self.fc1 = torch.nn.Linear(5, 16)
        self.fc2 = torch.nn.Linear(16, 8)
        self.fc3 = torch.nn.Linear(8, 1)

    def forward(self, x):
        """Map a (batch, 5) tensor to a (batch, 1) tensor of sigmoid scores."""
        hidden = torch.relu(self.fc2(torch.relu(self.fc1(x))))
        return torch.sigmoid(self.fc3(hidden))
# Build the network and load the trained weights from "model.pth" (a path
# relative to the current working directory).
model = AsteroidModel()
# map_location="cpu": this file only builds CPU input tensors, and without it a
# checkpoint saved from a GPU session fails to load on a CPU-only host.
# weights_only=True limits unpickling to tensor data rather than arbitrary
# Python objects.
model.load_state_dict(
    torch.load("model.pth", map_location="cpu", weights_only=True)
)
model.eval()  # Set model to evaluation mode
app = FastAPI()

# Cross-origin policy: accept requests from any origin, using any method and
# any header.
# NOTE(review): a wildcard origin together with allow_credentials=True is a
# risky pairing. Browsers refuse "*" for credentialed requests, and depending
# on the Starlette version the middleware may instead mirror the caller's
# Origin. Confirm credentials are actually needed, or list explicit origins.
_cors_settings = dict(
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
    allow_credentials=True,
)
app.add_middleware(CORSMiddleware, **_cors_settings)
# NOTE(review): "/predict" is an assumed path. The handler previously had no
# route decorator, so it was never exposed. Confirm the path matches clients.
@app.post("/predict")
async def predict(features: dict):
    """Classify a single asteroid as potentially hazardous.

    Args:
        features: JSON object mapping feature name -> numeric value. Values are
            fed to the model in the dict's iteration (insertion) order, i.e. the
            order the keys appear in the request body. Presumably this must
            match the column order used in training; TODO confirm the expected
            keys and their order.

    Returns:
        ``{"is_potentially_hazardous_asteroid": 1}`` when the model's sigmoid
        output exceeds 0.5, otherwise ``{"is_potentially_hazardous_asteroid": 0}``.

    Raises:
        HTTPException: status 422 when the number of features differs from the
            model's input width, or when any value cannot be converted to float.
    """
    from fastapi import HTTPException

    # Validate up front so bad input gets a 422 instead of a shape-mismatch 500.
    expected = model.fc1.in_features
    if len(features) != expected:
        raise HTTPException(
            status_code=422,
            detail=f"Expected {expected} features, got {len(features)}.",
        )
    try:
        values = [float(value) for value in features.values()]
    except (TypeError, ValueError) as err:
        raise HTTPException(
            status_code=422, detail="All feature values must be numeric."
        ) from err

    # Shape (1, n_features): a batch holding one sample.
    input_tensor = torch.tensor([values], dtype=torch.float32)
    with torch.no_grad():
        output = model(input_tensor).squeeze()  # (1, 1) -> 0-d tensor
    # Threshold the sigmoid score at 0.5 to get a 0/1 label.
    prediction = int((output > 0.5).item())
    return {"is_potentially_hazardous_asteroid": prediction}
import os

if __name__ == "__main__":
    import uvicorn

    # Listen on $PORT when set, otherwise on 7860.
    port = int(os.environ.get("PORT", 7860))
    # Pass the app object rather than the "app:app" import string. The string
    # form makes uvicorn import this file a second time as module `app`, which
    # rebuilds and reloads the model. It also only works if the file is named
    # app.py.
    uvicorn.run(app, host="0.0.0.0", port=port)