# Page status banner captured together with this file ("Spaces: Sleeping"); not part of the program.
import io
import logging
import traceback

import joblib
import numpy as np
import timm
import torch
import torchvision.transforms as transforms
import yaml
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from fastapi.logger import logger
from fastapi.responses import JSONResponse
from PIL import Image
app = FastAPI()
# Prefer the GPU when one is visible; passed as map_location when the
# checkpoint is read in load_model().
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def load_model(
    model_path='model.pth',
    model_name='convnext_base.clip_laiona',
    num_classes=3,
):
    """Create the timm classifier and load the fine-tuned weights into it.

    Args:
        model_path: Path to the saved state dict.
        model_name: timm architecture / pretrained tag to instantiate.
        num_classes: Output size of the classification head; must match the
            checkpoint.

    Returns:
        The model in eval mode. Its parameters are not moved to ``device``;
        ``device`` is only used as ``map_location`` while reading the file.
    """
    # pretrained=False: load_state_dict below is strict, so every parameter is
    # overwritten by the checkpoint. Downloading hub weights first would only
    # add startup latency and a network dependency.
    model = timm.create_model(model_name, pretrained=False, num_classes=num_classes)
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.eval()
    return model
# Build the model once at import time; every request reuses this instance.
model = load_model()
# Preprocessing pipeline, built once at import instead of on every request.
# NOTE(review): these are the ImageNet mean/std; timm's CLIP-pretrained
# ConvNeXt tags default to different normalization stats — confirm these match
# what the checkpoint was trained with.
_PREPROCESS = transforms.Compose([
    transforms.Resize(255),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
])


def transform_image(image_bytes):
    """Decode encoded image bytes into a normalized batch of one image.

    Args:
        image_bytes: Encoded image file contents (any format PIL can open).

    Returns:
        A float tensor of shape (1, 3, 224, 224).

    Raises:
        PIL.UnidentifiedImageError: if the bytes are not a decodable image.
    """
    # convert('RGB') forces 3 channels (drops alpha, expands grayscale).
    image = Image.open(io.BytesIO(image_bytes)).convert('RGB')
    return _PREPROCESS(image).unsqueeze(0)
def get_prediction(data):
    """Classify one encoded image and return its label.

    Args:
        data: Raw bytes of an uploaded image file.

    Returns:
        The ``reverse_mapping`` entry for the highest-scoring class index.
    """
    tensor = transform_image(data)
    # Put the input on the same device as the weights (a no-op when they
    # already match), so the call keeps working if the model is ever moved.
    tensor = tensor.to(next(model.parameters()).device)
    with torch.no_grad():
        logits = model(tensor)
    # NOTE(review): `reverse_mapping` (class index -> label) is not defined
    # anywhere in the visible source; presumably it was meant to be loaded with
    # joblib — confirm, otherwise this line raises NameError.
    return reverse_mapping[logits.argmax(dim=1).item()]
# Upload file extensions accepted by the /predict endpoint (lower-case, no dot).
ALLOWED_EXTENSIONS = {'png', 'jpg', 'jpeg'}


def allowed_file(filename):
    """Return True if *filename* has an extension listed in ALLOWED_EXTENSIONS.

    The comparison is case-insensitive. A missing or empty filename (an upload
    may carry no filename at all) is rejected instead of raising TypeError.
    """
    if not filename:
        return False
    _, dot, ext = filename.rpartition('.')
    return bool(dot) and ext.lower() in ALLOWED_EXTENSIONS
# NOTE(review): the route path is assumed. The handler was not registered with
# `app` at all, so it was unreachable.
@app.post("/predict")
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Args:
        file: Multipart upload; its filename must end in .png, .jpg or .jpeg
            (case-insensitive).

    Returns:
        JSONResponse ``{"class_name": <label>}`` on success, or
        ``{"error": <message>}`` with status 500 if decoding or inference fails.

    Raises:
        HTTPException: 400 when the file extension is not supported.
    """
    logger.info('API predict called')
    if not allowed_file(file.filename):
        raise HTTPException(status_code=400, detail="Format not supported")
    try:
        img_bytes = await file.read()
        # Model inference is synchronous and compute-bound; run it in the
        # worker thread pool so it does not stall the event loop for other
        # requests.
        class_name = await run_in_threadpool(get_prediction, img_bytes)
        logger.info('Prediction: %s', class_name)
        return JSONResponse(content={"class_name": class_name})
    except Exception as e:
        # Keep the full traceback in the server log only. Returning it to the
        # client would expose internal paths and code to anyone who can upload.
        logger.exception('Error: %s', e)
        return JSONResponse(content={"error": str(e)}, status_code=500)
@app.get("/")
def greet_json():
    """Root endpoint; returns a fixed JSON greeting."""
    return {"Hello": "World!"}