z4hid's picture
source code added
98d74fb verified
raw
history blame
2.45 kB
from PIL import Image
from torchvision import transforms
import torch
import os
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from io import BytesIO
# from dotenv import load_dotenv
from .model import MalwareNet, malware_classes # assuming malware_classes contains class names
# load_dotenv()
# FastAPI application instance; the endpoints below register themselves on it
# through the @app.get / @app.post decorators.
app = FastAPI()
# Preprocessing pipeline for the model, built once at import time rather than
# on every request. The mean/std values are the standard ImageNet statistics.
_PREPROCESS = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize to the input size expected by the model
    transforms.ToTensor(),  # PIL image -> float tensor (C, H, W)
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
])


def preprocess_image(image_path):
    """Load an image and turn it into a model-ready batch tensor.

    Args:
        image_path: Path (or file-like object accepted by ``PIL.Image.open``)
            of the image to classify.

    Returns:
        A float tensor of shape (1, 3, 224, 224): the image converted to RGB,
        resized to 224x224, normalized, with a leading batch dimension.
    """
    # Use the context manager so the underlying file handle is closed
    # deterministically; convert() returns a new, fully loaded image that
    # stays valid after the file is closed.
    with Image.open(image_path) as img:
        rgb = img.convert("RGB")
    return _PREPROCESS(rgb).unsqueeze(0)  # Add batch dimension
def load_model():
    """Build a ``MalwareNet`` and populate it with the trained weights.

    The state dict is read from ``../model/malwareNet.pt`` relative to this
    file's directory. It is mapped onto the CPU and loaded with
    ``weights_only=True``, which restricts unpickling to tensors and other
    allowlisted types.

    Returns:
        A fresh ``MalwareNet`` instance in evaluation mode.
    """
    net = MalwareNet()
    this_dir = os.path.dirname(os.path.abspath(__file__))
    weights_file = os.path.join(this_dir, '../model/malwareNet.pt')
    weights = torch.load(weights_file, map_location=torch.device('cpu'), weights_only=True)
    net.load_state_dict(weights)
    # nn.Module.eval() switches the module to evaluation mode and returns it.
    return net.eval()
@app.get("/")
def status():
    """Health-check endpoint; always answers with ``{"status": "ok"}``."""
    healthy = {"status": "ok"}
    return healthy
# Process-wide model instance, loaded from disk on first use and then reused,
# instead of re-reading the weights on every request.
_model = None


def _get_model():
    """Return the shared model, calling ``load_model()`` only the first time."""
    global _model
    if _model is None:
        _model = load_model()
    return _model


@app.post("/predict")
async def predict(data: dict):
    """Classify the image at a server-side file path.

    Args:
        data: JSON request body. Must contain ``"image_url"``, which despite
            its name is treated as a path on the server's local filesystem.

    Returns:
        JSONResponse with the submitted path under ``"image"`` and the
        predicted class name (looked up in ``malware_classes``) under
        ``"prediction"``.

    Raises:
        HTTPException: 400 if ``image_url`` is missing, not a string, or
            does not name an existing file; 500 if loading the image or
            model, or running inference, fails.
    """
    image_path = data.get("image_url")
    # Validate the type first: os.path.exists(None) raises TypeError (an
    # unhandled 500), and an int would be interpreted as a file descriptor.
    if not isinstance(image_path, str) or not image_path:
        raise HTTPException(status_code=400, detail="Missing or invalid 'image_url'.")
    # NOTE(review): the client chooses which server file is read. Confirm this
    # endpoint is not exposed to untrusted callers, or restrict paths to an
    # allowed directory.
    if not os.path.isfile(image_path):
        raise HTTPException(status_code=400, detail="Image path does not exist.")
    try:
        img_tensor = preprocess_image(image_path)
        model = _get_model()
        # NOTE(review): inference is synchronous inside an async handler, so it
        # blocks the event loop while it runs.
        with torch.no_grad():  # No gradient calculation is needed
            prediction = model(img_tensor)
        predicted_class = malware_classes[torch.argmax(prediction).item()]
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Error processing the image: {e}") from e
    return JSONResponse(content={"image": image_path, "prediction": predicted_class})
if __name__ == "__main__":
    # Development launcher: serves the app through uvicorn with auto-reload.
    # HOST and PORT may be overridden via environment variables.
    import uvicorn

    bind_host = os.environ.get("HOST", "localhost")
    bind_port = int(os.environ.get("PORT", 5000))
    uvicorn.run("src.serve:app", host=bind_host, port=bind_port, reload=True)