import functools
import logging
import os
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image, UnidentifiedImageError
from torchvision import transforms

# from dotenv import load_dotenv
from .model import MalwareNet, malware_classes  # assuming malware_classes contains class names

# load_dotenv()

logger = logging.getLogger(__name__)

app = FastAPI()

# Preprocessing pipeline, built once at import time instead of on every request.
# The mean/std values are the ImageNet statistics; presumably they match how
# MalwareNet was trained — TODO confirm against the training code.
PREPROCESS = transforms.Compose(
    [
        transforms.Resize((224, 224)),  # Resize to the input size expected by the model
        transforms.ToTensor(),  # Convert the PIL image to a float tensor
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ]
)

# Weights file, resolved relative to this source file so the server works
# regardless of the process's current working directory.
MODEL_PATH = os.path.normpath(
    os.path.join(os.path.dirname(os.path.abspath(__file__)), "..", "model", "malwareNet.pt")
)


def preprocess_image(image_path):
    """Load the image at *image_path* and turn it into a batched model input.

    The image is converted to RGB, resized to 224x224, converted to a tensor,
    normalized, and given a leading batch dimension.

    Args:
        image_path: Filesystem path of the image file.

    Returns:
        A float tensor of shape (1, 3, 224, 224).

    Raises:
        OSError: if the file cannot be opened.
        PIL.UnidentifiedImageError: if the file is not a recognised image format.
    """
    # Context manager closes the underlying file handle once the pixels are
    # decoded; convert() returns a new, fully loaded image, so it stays valid
    # after the source image is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    return PREPROCESS(image).unsqueeze(0)  # Add batch dimension


@functools.lru_cache(maxsize=1)
def load_model():
    """Build MalwareNet, load its weights onto the CPU, and return it in eval mode.

    The result is cached: the weights file is read only on the first successful
    call, and later calls (one per request) reuse the same model instance.
    Restart the server to pick up a new weights file.
    """
    model = MalwareNet()
    # weights_only=True restricts torch.load's unpickler to tensors and plain
    # containers instead of arbitrary Python objects.
    state_dict = torch.load(MODEL_PATH, map_location=torch.device("cpu"), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # Inference mode for layers such as dropout/batch-norm, if present
    return model


@app.get("/")
def status():
    """Liveness probe."""
    return {"status": "ok"}


@app.post("/predict")
def predict(data: dict):
    """Classify an image file that lives on the server's filesystem.

    Expects a JSON body such as ``{"image_url": "<path>"}``. Despite its name,
    ``image_url`` is treated as a local filesystem path, not a URL.

    Declared as a plain ``def`` (not ``async def``) on purpose: FastAPI runs
    sync handlers in a thread pool, whereas the blocking file read and
    CPU-bound inference below would otherwise stall the event loop for every
    other request.

    Returns:
        JSON ``{"image": <path>, "prediction": <class name>}``.

    Raises:
        HTTPException(400): ``image_url`` is missing or not a non-empty string,
            the path is not an existing regular file, or the file is not an image.
        HTTPException(500): any other failure during preprocessing or inference.
    """
    image_path = data.get("image_url")
    # NOTE(review): the client chooses which server-side file is read. If this
    # API is reachable by untrusted callers, restrict paths to an allow-listed
    # directory.
    if not isinstance(image_path, str) or not image_path:
        raise HTTPException(status_code=400, detail="Field 'image_url' must be a non-empty string.")
    # isfile rather than exists: directories would otherwise pass this check
    # and then fail inside PIL as a 500.
    if not os.path.isfile(image_path):
        raise HTTPException(status_code=400, detail="Image path does not exist.")

    try:
        img_tensor = preprocess_image(image_path)
        model = load_model()
        with torch.no_grad():  # No gradient calculation is needed for inference
            prediction = model(img_tensor)
        # argmax without `dim` flattens the output; that picks the top class
        # here because the batch holds exactly one image.
        predicted_class = malware_classes[torch.argmax(prediction).item()]
    except UnidentifiedImageError as e:
        # The file exists but is not a decodable image: a client error.
        raise HTTPException(status_code=400, detail=f"File is not a recognised image: {e}") from e
    except Exception as e:
        logger.exception("Prediction failed for %s", image_path)
        raise HTTPException(status_code=500, detail=f"Error processing the image: {e}") from e

    return JSONResponse(content={"image": image_path, "prediction": predicted_class})


if __name__ == "__main__":
    import uvicorn

    # reload=True is a development convenience (restart on code changes).
    uvicorn.run(
        "src.serve:app",
        host=os.environ.get("HOST", "localhost"),
        port=int(os.environ.get("PORT", 5000)),
        reload=True,
    )