"""FastAPI service that classifies skin-lesion images with a fine-tuned ViT.

POST an image to ``/predict``. The response is the most probable class, or
``"uncertain"`` when that class's probability is below its per-class threshold.
"""
import io
import os

# Set the Hugging Face cache directory to a writable location. This is done
# before `transformers` is imported so the setting is picked up.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"

import torch
import torch.nn as nn
import torchvision.transforms as transforms
from fastapi import FastAPI, File, HTTPException, UploadFile
from fastapi.concurrency import run_in_threadpool
from PIL import Image
from transformers import ViTFeatureExtractor, ViTForImageClassification

app = FastAPI()

# Load the ViT backbone and its feature extractor.
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(model_name)
# NOTE(review): `feature_extractor` is not used below; preprocessing is done by
# `transform`. It is kept so the module-level name remains available.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Class labels, index-aligned with the outputs of the trained classifier head.
class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses',
                'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']

# Per-class minimum probability for a prediction to be reported.
# Index-aligned with `class_labels`.
thresholds = [0.88134295, 0.43095806, 0.39622146, 0.90647435,
              0.8128958, 0.05310565, 0.15926854]

num_classes = 7
if not (len(class_labels) == len(thresholds) == num_classes):
    # Fail fast at startup rather than raising IndexError on some requests.
    raise RuntimeError(
        "class_labels, thresholds and num_classes must all describe the same number of classes"
    )

# Replace the classification head with one sized to our label set,
# then load the fine-tuned weights onto the CPU.
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
# NOTE(review): torch.load unpickles the file. Only load a trusted checkpoint;
# consider `weights_only=True` if the installed torch version supports it.
model.load_state_dict(torch.load("skin_cancer_model.pth", map_location=torch.device('cpu')))
model.eval()

# Image preprocessing: resize to the ViT input size, then convert to a float tensor in [0, 1].
# NOTE(review): no mean/std normalization is applied here. Confirm this matches the
# preprocessing used when `skin_cancer_model.pth` was trained.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def _load_rgb_image(contents: bytes) -> Image.Image:
    """Decode raw uploaded bytes into a 3-channel RGB PIL image.

    Raises:
        HTTPException: 400 if the bytes cannot be decoded as an image
            (including an empty upload).
    """
    try:
        image = Image.open(io.BytesIO(contents))
        # Force 3 channels. Grayscale, palette or RGBA uploads would otherwise
        # reach ToTensor with 1 or 4 channels, and the model would reject them.
        return image.convert("RGB")
    except OSError as exc:  # PIL's UnidentifiedImageError subclasses OSError
        raise HTTPException(status_code=400, detail="Uploaded file is not a valid image.") from exc


def _classify(image: Image.Image) -> dict:
    """Run the model on one RGB image and apply the class-specific threshold.

    Returns:
        ``{'predicted_class': str, 'accuracy': float}``, where ``predicted_class``
        is ``'uncertain'`` when the top probability is below that class's threshold.
        Despite its key name, ``'accuracy'`` holds the softmax probability of the
        top class (the key is kept for client compatibility).
    """
    batch = transform(image).unsqueeze(0)  # add a batch dimension of size 1
    with torch.no_grad():
        logits = model(batch).logits
    probabilities = torch.softmax(logits, dim=1)[0]

    predicted_idx = int(torch.argmax(probabilities))
    predicted_probability = float(probabilities[predicted_idx])

    if predicted_probability < thresholds[predicted_idx]:
        return {'predicted_class': 'uncertain', 'accuracy': predicted_probability}
    return {'predicted_class': class_labels[predicted_idx], 'accuracy': predicted_probability}


def _predict_from_bytes(contents: bytes) -> dict:
    """Decode the uploaded bytes and classify them (synchronous, CPU-bound)."""
    return _classify(_load_rgb_image(contents))


# API endpoint for model inference with class-specific thresholds.
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    """Classify an uploaded image.

    Responds with 400 if the upload is not a decodable image.
    """
    contents = await file.read()
    # Image decoding and model inference are CPU-bound. Run them in the worker
    # threadpool so this coroutine does not block the event loop for other requests.
    return await run_in_threadpool(_predict_from_bytes, contents)