Spaces:
Sleeping
Sleeping
import functools
import os
from io import BytesIO

import torch
from fastapi import FastAPI, HTTPException
from fastapi.responses import JSONResponse
from PIL import Image
from torchvision import transforms

# from dotenv import load_dotenv
from .model import MalwareNet, malware_classes  # assuming malware_classes contains class names
# load_dotenv() | |
# ASGI application object; the __main__ block at the bottom of this file serves
# it through uvicorn using the import string "src.serve:app".
app: FastAPI = FastAPI()
def preprocess_image(image_path):
    """Load an image and turn it into a single-image input batch for the model.

    The image is forced to 3-channel RGB, resized to 224x224, converted to a
    tensor and normalized per channel.

    Args:
        image_path: Path of the image file. Anything ``PIL.Image.open``
            accepts (a path-like object or a binary file object) also works.

    Returns:
        torch.Tensor: Float tensor of shape ``(1, 3, 224, 224)`` holding the
        preprocessed image with a leading batch dimension.

    Raises:
        Errors from ``PIL.Image.open`` for missing or unreadable files
        (e.g. ``FileNotFoundError``, ``PIL.UnidentifiedImageError``).
    """
    # Open the image as a context manager so the underlying file is closed
    # deterministically (Pillow otherwise keeps multi-frame files open until
    # garbage collection). convert() returns a fully loaded, independent
    # image, so it remains usable after the file is closed.
    with Image.open(image_path) as img:
        image = img.convert("RGB")

    preprocess = transforms.Compose([
        transforms.Resize((224, 224)),  # input size expected by the model
        transforms.ToTensor(),          # PIL image -> CHW float tensor in [0, 1]
        # Standard ImageNet channel statistics; presumably what MalwareNet was
        # trained with -- confirm against the training pipeline.
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        ),
    ])
    return preprocess(image).unsqueeze(0)  # add the batch dimension
@functools.lru_cache(maxsize=1)
def load_model():
    """Build MalwareNet, load its trained weights on the CPU and return it in eval mode.

    The result is cached. The checkpoint is read from disk only on the first
    successful call, and every later call returns the same model instance.
    Previously the weights were re-read on every request. A failed load is not
    cached, so the next call retries.

    Returns:
        MalwareNet: Model with weights from ``../model/malwareNet.pt``
        (relative to the directory containing this file), set to eval mode.

    Raises:
        Errors from ``torch.load`` (e.g. a missing checkpoint file) and from
        ``load_state_dict`` (e.g. mismatched parameter names or shapes).
    """
    model = MalwareNet()
    base_dir = os.path.dirname(os.path.abspath(__file__))
    # Checkpoint lives in a sibling "model" directory of this file's directory.
    model_location = os.path.normpath(
        os.path.join(base_dir, os.pardir, "model", "malwareNet.pt")
    )
    # weights_only=True restricts unpickling to tensors and plain containers,
    # so a tampered checkpoint cannot run arbitrary code on load.
    state_dict = torch.load(
        model_location, map_location=torch.device("cpu"), weights_only=True
    )
    model.load_state_dict(state_dict)
    model.eval()  # inference mode for layers such as dropout / batch-norm
    return model
def status():
    """Report that the service is up.

    Returns:
        dict: The constant health payload ``{"status": "ok"}``.

    NOTE(review): no route decorator is visible on this function; confirm it
    is registered on ``app`` elsewhere, or the health check is unreachable.
    """
    health = dict(status="ok")
    return health
def _classify_image(image_path):
    """Preprocess one image file, run the model on it and return the class name.

    Blocking helper (file I/O plus a CPU forward pass), kept separate so that
    ``predict`` can run it off the event loop.
    """
    img_tensor = preprocess_image(image_path)
    model = load_model()
    with torch.no_grad():  # inference only: no gradient tracking needed
        prediction = model(img_tensor)
    # argmax over the flattened output. Assumes the model returns one row of
    # class scores for the single-image batch, i.e. shape (1, num_classes);
    # confirm against MalwareNet.
    return malware_classes[torch.argmax(prediction).item()]


async def predict(data: dict):
    """Classify the image at a server-local path given in the request body.

    Args:
        data: Request payload. Must contain ``"image_url"``, which is treated
            as a path on this server's filesystem (it is not fetched as a URL).

    Returns:
        JSONResponse: ``{"image": <path>, "prediction": <class name>}``.

    Raises:
        HTTPException: 400 if ``image_url`` is missing, is not a non-empty
            string, or does not name an existing file; 500 if preprocessing
            or inference fails.

    NOTE(review): no route decorator is visible on this function; confirm it
    is registered on ``app`` elsewhere, otherwise the endpoint is unreachable.
    NOTE(review): the path comes straight from the client. Any caller can make
    the server open arbitrary local files and probe which paths exist
    (400 vs 500), and the 500 detail echoes internal error text. Consider
    restricting paths to an allowed directory or accepting an upload instead.
    """
    # fastapi is already a dependency of this module; imported locally to keep
    # the change self-contained.
    from fastapi.concurrency import run_in_threadpool

    image_path = data.get("image_url")
    # Without this check a missing key reaches os.path.exists(None), which
    # raises an uncaught TypeError and surfaces as a 500 instead of a 400.
    if not isinstance(image_path, str) or not image_path:
        raise HTTPException(
            status_code=400,
            detail="Request body must include a non-empty 'image_url' string.",
        )
    # isfile (not exists): reject directories here instead of failing later
    # inside Image.open with a 500.
    if not os.path.isfile(image_path):
        raise HTTPException(
            status_code=400, detail="Image path does not exist or is not a file."
        )
    try:
        # Image decoding and the forward pass are blocking; running them in
        # FastAPI's threadpool keeps the event loop free for other requests.
        predicted_class = await run_in_threadpool(_classify_image, image_path)
    except Exception as e:
        raise HTTPException(
            status_code=500, detail=f"Error processing the image: {e}"
        ) from e
    return JSONResponse(content={"image": image_path, "prediction": predicted_class})
if __name__ == "__main__":
    # Local/development entry point. uvicorn needs the app as an import string
    # ("src.serve:app") rather than the object for reload mode to work.
    # Because this module uses a relative import (from .model ...), run it as
    # part of its package, e.g. `python -m src.serve`.
    import uvicorn

    env = os.environ
    bind_host = env.get("HOST", "localhost")
    bind_port = int(env.get("PORT", 5000))
    uvicorn.run("src.serve:app", host=bind_host, port=bind_port, reload=True)