# Scraped from a hosted "Spaces" page (status: Sleeping);
# revision 98d74fb, file size 2,451 bytes.
import functools
import os
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms

# from dotenv import load_dotenv
from .model import MalwareNet, malware_classes  # assuming malware_classes contains class names
# load_dotenv()  # disabled, together with the commented-out dotenv import above
# ASGI application instance: the route decorators below register on it, and the
# __main__ block serves it through uvicorn as "src.serve:app".
app = FastAPI()
# Preprocessing function for the model
def preprocess_image(image_path):
image = Image.open(image_path).convert("RGB")
preprocess = transforms.Compose([
transforms.Resize((224, 224)), # Resize to the input size expected by the model
transforms.ToTensor(), # Convert to tensor
transforms.Normalize( # Normalize using model's requirements (e.g. ImageNet)
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225]
)
])
return preprocess(image).unsqueeze(0) # Add batch dimension
# Load model and its weights. The result is cached, so the checkpoint is read
# from disk once per process rather than on every request.
@functools.lru_cache(maxsize=1)
def load_model():
    """Build ``MalwareNet``, load its CPU weights and switch it to eval mode.

    The checkpoint is ``../model/malwareNet.pt`` relative to this file. Thanks to
    the cache, every call returns the *same* model instance. Callers must not
    mutate it (e.g. call ``.train()``). Use ``load_model.cache_clear()`` to
    force a reload. A failed load raises and is not cached.

    Returns:
        The ``MalwareNet`` instance in evaluation mode.
    """
    model = MalwareNet()
    base_dir = os.path.dirname(os.path.abspath(__file__))
    # normpath collapses the ".." so logs and error messages show a clean path.
    model_location = os.path.normpath(
        os.path.join(base_dir, "..", "model", "malwareNet.pt")
    )
    # weights_only=True restricts unpickling to tensors and primitive containers.
    state_dict = torch.load(
        model_location, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()  # Set the model to evaluation mode
    return model
@app.get("/")
def status():
    """Liveness probe: always reports that the service is up."""
    health = dict(status="ok")
    return health
def _predict_sync(image_path):
    """Preprocess, run inference, and return the predicted class name (blocking)."""
    img_tensor = preprocess_image(image_path)
    model = load_model()
    with torch.no_grad():  # No gradient calculation is needed for inference
        prediction = model(img_tensor)
    return malware_classes[torch.argmax(prediction).item()]


@app.post("/predict")
async def predict(data: dict):
    """Classify a server-side image file and return the predicted class.

    Expects a JSON body with an ``"image_url"`` key. Despite its name, the value
    is treated as a filesystem path on the server, not fetched over HTTP.

    Returns:
        JSONResponse ``{"image": <path>, "prediction": <class name>}``.

    Raises:
        HTTPException(400): ``image_url`` is missing or not a string, or does
            not name an existing regular file.
        HTTPException(500): preprocessing, model loading, or inference failed.
    """
    # Imported here so this handler does not depend on the module import block.
    from fastapi.concurrency import run_in_threadpool

    image_path = data.get("image_url")
    # Without this check, a missing key reached os.path.exists(None) and
    # escaped as an unhandled TypeError (a bare 500).
    if not isinstance(image_path, str):
        raise HTTPException(
            status_code=400,
            detail="Request body must contain an 'image_url' string.",
        )
    # NOTE(review): the path is fully client-controlled, so any file readable by
    # the server can be probed. Consider restricting it to an allowed directory.
    # isfile (not exists) rejects directories up front instead of failing with a 500.
    if not os.path.isfile(image_path):
        raise HTTPException(status_code=400, detail="Image path does not exist.")
    try:
        # Decoding and inference are CPU-bound. Running them in the threadpool
        # keeps this async handler from blocking the event loop.
        predicted_class = await run_in_threadpool(_predict_sync, image_path)
        return JSONResponse(content={"image": image_path, "prediction": predicted_class})
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing the image: {e}"
        ) from e
if __name__ == "__main__":
    # Development entry point: serve the app with auto-reload. HOST and PORT
    # can be overridden through environment variables.
    import uvicorn

    bind_host = os.environ.get("HOST", "localhost")
    bind_port = int(os.environ.get("PORT", 5000))
    uvicorn.run("src.serve:app", host=bind_host, port=bind_port, reload=True)