Anwarkh1's picture
Update main.py
01c57ef verified
raw
history blame
2.31 kB
import io
import os

# Set cache directory to a writable location. This is done before importing
# transformers so the setting is already in place when the library is loaded.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
app = FastAPI()

# Pretrained checkpoint that supplies the ViT backbone and its feature extractor.
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(model_name)
# NOTE(review): feature_extractor is loaded but not referenced by the code
# visible in this file; kept in case other code relies on it.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Replace the classification head with a `num_classes`-way linear layer, then
# load the trained weights from disk, mapping all tensors onto the CPU.
num_classes = 7
model.classifier = nn.Linear(
    in_features=model.config.hidden_size,
    out_features=num_classes,
)
model.load_state_dict(
    torch.load("skin_cancer_model.pth", map_location=torch.device("cpu"))
)
model.eval()  # switch the module to evaluation mode for inference

# Human-readable label for each output index of the classifier.
class_labels = [
    'benign_keratosis-like_lesions',
    'basal_cell_carcinoma',
    'actinic_keratoses',
    'vascular_lesions',
    'melanocytic_Nevi',
    'melanoma',
    'dermatofibroma',
]

# Minimum probability required to report each class; indexed in parallel
# with `class_labels` (entry i is the threshold for class_labels[i]).
thresholds = [
    0.88134295,
    0.43095806,
    0.39622146,
    0.90647435,
    0.8128958,
    0.05310565,
    0.15926854,
]

# Preprocessing applied to every uploaded image: resize to 224x224 and
# convert to a float tensor.
# NOTE(review): no mean/std normalization is applied here -- confirm this
# matches the preprocessing used when the weights were trained.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)
# Define API endpoint for model inference with class-specific thresholds
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image and apply the class-specific threshold.

    The upload is decoded, converted to RGB, passed through the module-level
    ``transform`` and ``model``, and the class with the highest softmax
    probability is selected. If that probability is below the threshold
    configured for the selected class, the prediction is reported as
    ``'uncertain'``.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        A dict with two keys:

        * ``'predicted_class'``: the winning class label, or ``'uncertain'``
          when the winning probability is below that class's threshold.
        * ``'accuracy'``: the softmax probability of the winning class as a
          float. Despite the key name this is a per-prediction probability,
          not a measured accuracy; the key is kept for API compatibility.

    Raises:
        HTTPException: With status 400 when the upload cannot be decoded as
            an image.
    """
    contents = await file.read()

    try:
        # Force three channels: uploads may be RGBA, grayscale, palette or
        # CMYK, which would otherwise yield a tensor whose channel count the
        # model (assumed to take 3-channel input) does not accept.
        image = Image.open(io.BytesIO(contents)).convert("RGB")
    except OSError as exc:
        # Covers PIL.UnidentifiedImageError (an OSError subclass) as well as
        # truncated or corrupt files that fail while decoding.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from exc

    batch = transform(image).unsqueeze(0)  # add the leading batch dimension

    with torch.no_grad():
        logits = model(batch).logits

    # Softmax over the class dimension; [0] selects the single batch item.
    probabilities = torch.softmax(logits, dim=1)[0]
    predicted_idx = int(torch.argmax(probabilities))
    predicted_probability = float(probabilities[predicted_idx])

    # Report the label only if the probability meets the class-specific threshold.
    if predicted_probability < thresholds[predicted_idx]:
        predicted_label = 'uncertain'
    else:
        predicted_label = class_labels[predicted_idx]

    return {'predicted_class': predicted_label, 'accuracy': predicted_probability}