dmusingu's picture
Update app.py
823b2fe verified
raw
history blame
2.37 kB
import io
import logging
import traceback

import joblib
import numpy as np
import timm
import torch
import torchvision.transforms as transforms
import yaml
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.logger import logger
from fastapi.responses import JSONResponse
from PIL import Image
app = FastAPI()

# Use the GPU when PyTorch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def load_model():
# config = read_params(config_path)
model = timm.create_model('convnext_base.clip_laiona', pretrained=True, num_classes=3)
model_state_dict = torch.load('model.pth', map_location=device)
model.load_state_dict(model_state_dict)
model.eval()
return model
model = load_model()
# Preprocessing pipeline. It is stateless, so build it once at import time
# rather than on every request.
# NOTE(review): Resize(255) (vs. the more common 256) is kept unchanged so
# inputs keep matching the preprocessing used so far. Confirm against training.
_PREPROCESS = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    # Commonly used ImageNet per-channel mean / std.
    transforms.Normalize(
        [0.485, 0.456, 0.406],
        [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Decode encoded image bytes into a normalised, batched model input.

    The image is converted to RGB. Its shorter side is resized to 255, it is
    center-cropped to 224x224, converted to a float tensor and normalised.

    Args:
        image_bytes: Raw bytes of an encoded image (e.g. PNG or JPEG).

    Returns:
        A tensor of shape ``(1, 3, 224, 224)``; the leading dim is the batch.

    Raises:
        PIL errors (e.g. ``UnidentifiedImageError``) if the bytes cannot be
        decoded as an image.
    """
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return _PREPROCESS(image).unsqueeze(0)
def get_prediction(data):
    """Classify raw image bytes and return the label for the top-scoring class.

    Args:
        data: Raw bytes of an encoded image (e.g. PNG or JPEG).

    Returns:
        ``reverse_mapping[i]``, where ``i`` is the argmax class index.
    """
    tensor = transform_image(data)
    # Put the input on the same device as the model's weights, so inference
    # works whether the model lives on the CPU or the GPU.
    model_device = next(model.parameters()).device
    tensor = tensor.to(model_device)
    with torch.no_grad():
        logits = model(tensor)
    # Assumes (batch, num_classes) logits, as timm classifiers return. Take the
    # argmax over the class dimension for the single sample in the batch,
    # rather than over the flattened tensor.
    class_index = logits.argmax(dim=1)[0].item()
    # NOTE(review): `reverse_mapping` (class index -> label) is not defined
    # anywhere in this file. Unless it is provided elsewhere, this line raises
    # NameError on every call. TODO: define it alongside the model.
    return reverse_mapping[class_index]
# Lower-case file extensions (without the dot) accepted by /predict.
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}


def allowed_file(filename):
    """Return True if *filename* ends in an allowed image extension.

    The check is case-insensitive and only looks at the text after the last
    dot. ``None`` or an empty name (``UploadFile.filename`` may be ``None``)
    is rejected instead of raising ``TypeError``.
    """
    if not filename:
        return False
    _, dot, extension = filename.rpartition('.')
    return bool(dot) and extension.lower() in ALLOWED_EXTENSIONS
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """
    Perform prediction on the uploaded image.

    Responses:
        400 -- the filename is missing or its extension is not png/jpg/jpeg.
        200 -- ``{"class_name": <predicted label>}``.
        500 -- ``{"error": <message>}`` if decoding or inference fails. The full
               traceback is logged server-side and not returned to the client.
    """
    logger.info('API predict called')
    # UploadFile.filename can be None; reject it before inspecting the extension.
    if not file.filename or not allowed_file(file.filename):
        raise HTTPException(status_code=400, detail="Format not supported")
    try:
        img_bytes = await file.read()
        # Inference is synchronous and compute-heavy. Run it in the worker
        # thread pool so it does not block the event loop for other requests.
        class_name = await run_in_threadpool(get_prediction, img_bytes)
    except Exception as e:
        # Top-level request boundary: log the traceback for operators, but do
        # not expose server internals (stack frames, paths) in the response.
        logger.exception(f'Error: {str(e)}')
        return JSONResponse(content={"error": str(e)}, status_code=500)
    logger.info(f'Prediction: {class_name}')
    return JSONResponse(content={"class_name": class_name})
@app.get("/")
def greet_json():
    """Root endpoint: respond with a fixed JSON greeting."""
    greeting = dict(Hello="World!")
    return greeting