File size: 2,309 Bytes
ce3c4cd
 
 
bf03d3b
ae4ebd6
bf03d3b
 
 
 
 
88b4498
bf03d3b
 
 
ae4ebd6
bf03d3b
ae4ebd6
 
 
 
 
f41dca6
ae4ebd6
bf03d3b
49db250
bf03d3b
 
 
49db250
 
 
bf03d3b
 
 
 
 
 
49db250
bf03d3b
 
 
 
f41dca6
bf03d3b
 
 
 
 
49db250
bf03d3b
 
49db250
bf03d3b
49db250
bf03d3b
49db250
 
01c57ef
49db250
01c57ef
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
# Point the Hugging Face cache at a writable directory. This assignment sits
# before the `transformers` import on purpose, presumably so the library sees
# the value when it initialises its cache paths.
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME — confirm against the installed version.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"  # Set cache directory to a writable location

import io

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from fastapi import FastAPI, File, HTTPException, UploadFile
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification

# FastAPI application object; the /predict route is registered on it.
app = FastAPI()

# Load the ViT model and its feature extractor.
# The architecture and pretrained weights are fetched by hub name via
# from_pretrained; the classification head is replaced further down.
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(model_name)
# NOTE(review): feature_extractor is not used by the /predict endpoint in this
# file (preprocessing is done by the torchvision `transform`). Recent
# transformers releases also deprecate ViTFeatureExtractor in favour of
# ViTImageProcessor — confirm before removing or upgrading.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Load the trained model weights.
# num_classes should equal the number of entries in `class_labels` (7), since
# the endpoint indexes that list with the argmax over this head's outputs.
num_classes = 7
# Swap in a num_classes-way linear head BEFORE loading the checkpoint: the
# checkpoint presumably stores a classifier of this shape, which would not fit
# the head created by from_pretrained.
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
# map_location forces every stored tensor onto the CPU, whatever device the
# checkpoint was saved from.
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints;
# consider weights_only=True if the installed torch supports it.
model.load_state_dict(torch.load("skin_cancer_model.pth", map_location=torch.device('cpu')))
# Switch layers such as dropout to inference behaviour.
model.eval()


# Display names for the classifier outputs: entry i is the label reported for
# output index i of the model's classification head.
class_labels = [
    'benign_keratosis-like_lesions',
    'basal_cell_carcinoma',
    'actinic_keratoses',
    'vascular_lesions',
    'melanocytic_Nevi',
    'melanoma',
    'dermatofibroma',
]

# Per-class minimum softmax probability for a confident prediction:
# thresholds[i] applies to class_labels[i]. When the winning class falls
# below its threshold, the endpoint answers 'uncertain' instead.
thresholds = [
    0.88134295,  # benign_keratosis-like_lesions
    0.43095806,  # basal_cell_carcinoma
    0.39622146,  # actinic_keratoses
    0.90647435,  # vascular_lesions
    0.8128958,   # melanocytic_Nevi
    0.05310565,  # melanoma
    0.15926854,  # dermatofibroma
]

# Preprocessing applied to every uploaded image before inference: resize to
# 224x224 pixels, then convert the PIL image into a tensor.
# NOTE(review): there is no Normalize step here — confirm this matches the
# preprocessing used when skin_cancer_model.pth was trained.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
    ]
)

# API endpoint for model inference with class-specific confidence thresholds.
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded skin-lesion image.

    The upload is decoded with PIL, converted to RGB, preprocessed by the
    module-level ``transform`` and passed through ``model``. The arg-max
    class is reported only if its softmax probability reaches that class's
    entry in ``thresholds``; otherwise the result is ``'uncertain'``.

    Args:
        file: The uploaded image file.

    Returns:
        ``{'predicted_class': str, 'accuracy': float}``. Despite its name,
        ``'accuracy'`` holds the softmax probability of the arg-max class;
        the key is kept unchanged for compatibility with existing clients.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    try:
        # Image.open is lazy; convert() forces a full decode, so empty,
        # truncated or corrupt uploads fail here instead of inside the model.
        # Converting to RGB also turns grayscale/RGBA/palette images into the
        # 3-channel input the model expects. The context manager releases
        # the original image; `rgb` is an independent copy.
        with Image.open(io.BytesIO(contents)) as img:
            rgb = img.convert('RGB')
    except (OSError, Image.DecompressionBombError) as err:
        # PIL's UnidentifiedImageError (unrecognised data) subclasses OSError.
        raise HTTPException(
            status_code=400,
            detail='Uploaded file is not a readable image.',
        ) from err

    batch = transform(rgb).unsqueeze(0)  # add a leading batch dimension

    with torch.no_grad():
        logits = model(batch).logits

    # Softmax over the class dimension for the single image in the batch.
    probabilities = torch.softmax(logits, dim=1)[0]
    predicted_idx = int(torch.argmax(probabilities))
    predicted_probability = float(probabilities[predicted_idx])

    # Only commit to a label when the model clears that class's threshold.
    if predicted_probability < thresholds[predicted_idx]:
        predicted_class = 'uncertain'
    else:
        predicted_class = class_labels[predicted_idx]
    return {'predicted_class': predicted_class, 'accuracy': predicted_probability}