|
import os |
|
# Redirect the Hugging Face download cache to /tmp (presumably the writable
# location in the deployment environment -- confirm). This sits above the
# `transformers` import, presumably so the library sees it when it loads.
# NOTE(review): newer transformers releases deprecate TRANSFORMERS_CACHE in
# favour of HF_HOME; confirm against the pinned transformers version.
os.environ.update(TRANSFORMERS_CACHE="/tmp/huggingface_cache")
|
|
|
from fastapi import FastAPI, UploadFile, File |
|
from transformers import ViTForImageClassification, ViTFeatureExtractor |
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms as transforms |
|
from PIL import Image |
|
import io |
|
|
|
app = FastAPI()

# --- Model initialisation (executed once, at import time) -------------------
# The base checkpoint supplies the ViT config/backbone; every parameter is then
# overwritten by the fine-tuned state dict loaded below (load_state_dict is
# strict by default, so all keys must match).
model_name = "google/vit-base-patch16-224-in21k"

model = ViTForImageClassification.from_pretrained(model_name)
# NOTE(review): feature_extractor is not referenced by predict(); presumably
# kept for parity with training-time preprocessing -- confirm or remove.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Replace the checkpoint's classification head with a linear layer sized for
# this project's label set so the saved state dict's head shapes line up.
num_classes = 7
model.classifier = nn.Linear(
    in_features=model.config.hidden_size,
    out_features=num_classes,
)

# Load the fine-tuned weights onto the CPU regardless of the device they were
# saved from, then switch the model to evaluation behaviour for inference.
# NOTE(review): torch.load unpickles the file; on torch versions that support
# it, weights_only=True is the safer choice for a plain state dict.
model.load_state_dict(
    torch.load("skin_cancer_model.pth", map_location=torch.device('cpu'))
)
model.eval()
|
|
|
|
|
|
|
# Human-readable names for the model's output indices: predict() looks up
# class_labels[i] for the argmax index i of the softmax output, so the order
# here must match the order of the classifier's logits.
class_labels = [
    'benign_keratosis-like_lesions',
    'basal_cell_carcinoma',
    'actinic_keratoses',
    'vascular_lesions',
    'melanocytic_Nevi',
    'melanoma',
    'dermatofibroma',
]

# Per-class minimum softmax probability, aligned index-for-index with
# class_labels. When the top-1 probability is below its class's threshold,
# the endpoint reports 'uncertain' instead of the label.
thresholds = [
    0.88134295,
    0.43095806,
    0.39622146,
    0.90647435,
    0.8128958,
    0.05310565,
    0.15926854,
]

# Fail fast at import time rather than with an IndexError inside a request
# if the two parallel lists ever drift out of alignment.
if len(thresholds) != len(class_labels):
    raise ValueError(
        f"thresholds has {len(thresholds)} entries but class_labels has "
        f"{len(class_labels)}; they must align index-for-index."
    )
|
|
|
|
|
# Per-request preprocessing: resize every upload to 224x224 and convert it to a
# float tensor (ToTensor scales 8-bit pixel values into [0, 1]).
# NOTE(review): no transforms.Normalize is applied; confirm this matches the
# preprocessing used when skin_cancer_model.pth was trained.
transform = transforms.Compose(
    [transforms.Resize(size=(224, 224)), transforms.ToTensor()]
)
|
|
|
|
|
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded skin-lesion image.

    The upload is decoded with Pillow, converted to 3-channel RGB, run through
    ``transform`` and ``model``, and the top-1 class is returned unless its
    softmax probability is below that class's entry in ``thresholds``, in
    which case ``'uncertain'`` is reported instead.

    Args:
        file: The uploaded image (multipart form field ``file``).

    Returns:
        ``{'predicted_class': <label or 'uncertain'>, 'accuracy': <float>}``.
        Despite its name, ``'accuracy'`` carries the model's softmax
        probability for the top-1 class; the key is kept for client
        compatibility.

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    # Local import so this block does not depend on an edit to the file header.
    from fastapi import HTTPException

    contents = await file.read()
    try:
        with Image.open(io.BytesIO(contents)) as uploaded:
            # Force 3 channels: RGBA/LA PNGs, grayscale and palette images would
            # otherwise yield tensors with 1 or 4 channels, which the ViT
            # backbone (presumably configured for RGB) rejects.
            rgb_image = uploaded.convert("RGB")
    except (OSError, Image.DecompressionBombError) as exc:
        # UnidentifiedImageError (empty or non-image upload) and truncated
        # image data are both OSError subclasses.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a readable image.",
        ) from exc

    batch = transform(rgb_image).unsqueeze(0)  # add a leading batch dimension

    # NOTE(review): this forward pass is synchronous CPU work inside an async
    # endpoint, so it blocks the event loop while it runs.
    with torch.no_grad():
        outputs = model(batch)

    probabilities = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
    predicted_idx = int(probabilities.argmax())
    predicted_probability = float(probabilities[predicted_idx])

    if predicted_probability < thresholds[predicted_idx]:
        predicted_class = 'uncertain'
    else:
        predicted_class = class_labels[predicted_idx]
    return {'predicted_class': predicted_class, 'accuracy': predicted_probability}