import os
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache" # Set cache directory to a writable location
from fastapi import FastAPI, UploadFile, File
from transformers import ViTForImageClassification, ViTImageProcessor
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import io
app = FastAPI()
# Load the ViT backbone and its image processor (ViTFeatureExtractor is deprecated in recent transformers)
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(model_name)
image_processor = ViTImageProcessor.from_pretrained(model_name)  # loaded for reference; preprocessing below uses torchvision
# Replace the classification head to match the 7 skin-lesion classes
num_classes = 7
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
# Load the fine-tuned weights (weights_only=True is the safer choice for a plain state dict)
model.load_state_dict(torch.load("skin_cancer_model.pth", map_location=torch.device('cpu'), weights_only=True))
model.eval()
# Define class labels
class_labels = [
    'benign_keratosis-like_lesions',
    'basal_cell_carcinoma',
    'actinic_keratoses',
    'vascular_lesions',
    'melanocytic_Nevi',
    'melanoma',
    'dermatofibroma',
]
# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
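# Note: this pipeline applies no mean/std normalization; it assumes the
# fine-tuned weights in skin_cancer_model.pth were trained with this same
# Resize + ToTensor preprocessing. If the model was instead trained with the
# ViT image processor's normalization, add a matching transforms.Normalize here.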
# Define API endpoint for model inference
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    contents = await file.read()
    image = Image.open(io.BytesIO(contents)).convert("RGB")  # ensure 3 channels (handles RGBA/grayscale uploads)
    image = transform(image).unsqueeze(0)  # add batch dimension
    with torch.no_grad():
        outputs = model(image)
    # Convert logits to softmax probabilities
    probabilities = torch.softmax(outputs.logits, dim=1)
    # Get the predicted class index and its softmax probability
    predicted_idx = torch.argmax(probabilities, dim=1).item()
    predicted_label = class_labels[predicted_idx]
    confidence = probabilities[0][predicted_idx].item()  # per-prediction probability, not dataset accuracy
    return {'predicted_class': predicted_label, 'confidence': confidence}
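
# A minimal way to run and exercise this service locally; a sketch that assumes
# this file is saved as main.py and that uvicorn is installed. The file name
# lesion.jpg in the curl example is a placeholder.
#
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl -X POST http://localhost:8000/predict -F "file=@lesion.jpg"
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)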