File size: 2,309 Bytes
ce3c4cd bf03d3b ae4ebd6 bf03d3b 88b4498 bf03d3b ae4ebd6 bf03d3b ae4ebd6 f41dca6 ae4ebd6 bf03d3b 49db250 bf03d3b 49db250 bf03d3b 49db250 bf03d3b f41dca6 bf03d3b 49db250 bf03d3b 49db250 bf03d3b 49db250 bf03d3b 49db250 01c57ef 49db250 01c57ef |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
import io
import os

# Set before any third-party import below, presumably so `transformers`
# picks up the cache location when it loads.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache" # Set cache directory to a writable location

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification
app = FastAPI()

# Load the ViT backbone and its feature extractor.
# NOTE(review): `feature_extractor` is not used by the /predict handler in this
# file, which preprocesses with `transform` below instead; kept because other
# code may import it.
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(model_name)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Replace the classification head with a `num_classes`-way linear layer, then
# load the fine-tuned weights. The head is swapped *before* load_state_dict,
# presumably because the checkpoint's classifier tensors have `num_classes`
# outputs and must match the module's shape.
num_classes = 7
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
# weights_only=True restricts unpickling to tensors and plain containers (the
# default only from torch 2.6 onward), so a tampered checkpoint file cannot
# execute arbitrary code at load time. A state dict needs nothing more.
model.load_state_dict(
    torch.load(
        "skin_cancer_model.pth",
        map_location=torch.device('cpu'),
        weights_only=True,
    )
)
model.eval()  # inference mode (affects layers such as dropout)

# Class labels, indexed by the model's output position.
class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
# Per-class minimum softmax probability required before the handler reports
# that class; indexed like `class_labels`. Described as "optimal" thresholds;
# how they were derived is not visible here.
thresholds = [0.88134295, 0.43095806, 0.39622146, 0.90647435, 0.8128958, 0.05310565, 0.15926854]

# Fail fast at startup instead of raising IndexError (or silently pairing a
# class with the wrong threshold) on the first request.
if not (len(class_labels) == len(thresholds) == num_classes):
    raise ValueError(
        f"Expected {num_classes} class labels and thresholds, got "
        f"{len(class_labels)} labels and {len(thresholds)} thresholds"
    )

# Preprocessing: resize to 224x224 and convert to a CxHxW float tensor.
# NOTE(review): no Normalize step is applied; this must match the
# preprocessing used when the weights were trained — confirm.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# API endpoint: classify an uploaded image, applying class-specific thresholds.
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image with the fine-tuned ViT model.

    The upload is decoded, converted to 3-channel RGB, preprocessed with
    ``transform`` and passed through ``model``. The arg-max class is reported
    only if its softmax probability reaches that class's entry in
    ``thresholds``; otherwise the prediction is ``'uncertain'``.

    Args:
        file: The uploaded image file.

    Returns:
        A dict with ``'predicted_class'`` (a label from ``class_labels`` or
        ``'uncertain'``) and ``'accuracy'``. Despite its name, ``'accuracy'``
        is the softmax probability of the arg-max class, not a measured
        accuracy; the key is kept for compatibility with existing clients.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        # Force 3 channels: RGBA/paletted PNGs, grayscale and CMYK images
        # would otherwise produce 4- or 1-channel tensors that the model
        # rejects. convert() also makes PIL decode the full image, so
        # truncated files fail here rather than inside the model.
        image = Image.open(io.BytesIO(contents)).convert('RGB')
    except OSError as err:  # includes PIL.UnidentifiedImageError
        raise HTTPException(
            status_code=400, detail='Uploaded file is not a valid image.'
        ) from err

    # `transform` yields a (3, 224, 224) tensor; unsqueeze adds the batch dim.
    batch = transform(image).unsqueeze(0)
    # NOTE(review): this synchronous forward pass runs inside an `async def`
    # handler, so it blocks the event loop while it executes.
    with torch.no_grad():
        logits = model(batch).logits

    # Softmax over the class dimension; [0] selects the single image.
    probabilities = torch.softmax(logits, dim=1)[0]
    predicted_idx = int(torch.argmax(probabilities))
    predicted_probability = float(probabilities[predicted_idx])

    # Report the arg-max class only if it clears that class's threshold.
    if predicted_probability < thresholds[predicted_idx]:
        predicted_class = 'uncertain'
    else:
        predicted_class = class_labels[predicted_idx]
    return {'predicted_class': predicted_class, 'accuracy': predicted_probability}