"""FastAPI service that classifies skin-lesion images with a fine-tuned ViT.

Exposes a single endpoint, ``POST /predict``, which accepts an uploaded image
file and returns the most probable class label together with its softmax
probability.
"""

import os

# Set the cache directory to a writable location. This assignment is kept
# ahead of the `transformers` import (as in the original statement order) so
# the library sees the value when it is loaded.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache"

import io  # noqa: E402  (imports intentionally follow the env var setup)

import torch  # noqa: E402
import torch.nn as nn  # noqa: E402
import torchvision.transforms as transforms  # noqa: E402
from fastapi import FastAPI, File, HTTPException, UploadFile  # noqa: E402
from PIL import Image  # noqa: E402
from transformers import (  # noqa: E402
    ViTFeatureExtractor,
    ViTForImageClassification,
)

app = FastAPI()

# Base checkpoint the fine-tuned weights were derived from.
model_name = "google/vit-base-patch16-224-in21k"

# Number of output classes of the fine-tuned classification head.
num_classes = 7

# Class labels, indexed by the model's output position.
class_labels = [
    'benign_keratosis-like_lesions',
    'basal_cell_carcinoma',
    'actinic_keratoses',
    'vascular_lesions',
    'melanocytic_Nevi',
    'melanoma',
    'dermatofibroma',
]


def _load_model(weights_path: str = "skin_cancer_model.pth"):
    """Build the ViT classifier and load the fine-tuned weights on CPU."""
    net = ViTForImageClassification.from_pretrained(model_name)
    # Replace the head so its shape matches the saved state dict.
    net.classifier = nn.Linear(net.config.hidden_size, num_classes)
    net.load_state_dict(
        torch.load(weights_path, map_location=torch.device('cpu'))
    )
    net.eval()
    return net


model = _load_model()

# Not used by the /predict endpoint below; kept because it is a module-level
# name that other code may import.
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)

# Image transformations applied before inference.
# NOTE(review): this only resizes and scales pixels to [0, 1]; no mean/std
# normalization is applied. Confirm this matches the preprocessing used when
# skin_cancer_model.pth was trained.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])


def _decode_image(raw: bytes) -> Image.Image:
    """Decode uploaded bytes into a 3-channel RGB PIL image.

    Raises:
        HTTPException: 400 if the bytes are not a readable image.
    """
    try:
        # convert("RGB") forces a full decode and guarantees three channels,
        # so RGBA / grayscale / palette uploads do not break the model input.
        return Image.open(io.BytesIO(raw)).convert("RGB")
    except OSError as err:
        # PIL signals unidentified or truncated images with OSError subclasses.
        raise HTTPException(
            status_code=400,
            detail="Uploaded file is not a valid image.",
        ) from err


@app.post('/predict')
async def predict(file: UploadFile = File(...)) -> dict:
    """Classify an uploaded image.

    Args:
        file: The uploaded image file (multipart form field ``file``).

    Returns:
        A dict with ``predicted_class`` (the label from ``class_labels``) and
        ``accuracy`` (the softmax probability of that class for this image;
        the key name is kept for API compatibility).

    Raises:
        HTTPException: 400 if the upload cannot be decoded as an image.
    """
    contents = await file.read()
    image = _decode_image(contents)
    batch = transform(image).unsqueeze(0)  # Add batch dimension

    with torch.no_grad():
        outputs = model(batch)
        # Softmax over the class dimension.
        probabilities = torch.softmax(outputs.logits, dim=1)

    # Explicit dim=1 so the index refers to a class, not a flattened position.
    predicted_idx = torch.argmax(probabilities, dim=1)[0].item()
    predicted_label = class_labels[predicted_idx]
    predicted_accuracy = probabilities[0][predicted_idx].item()

    return {'predicted_class': predicted_label, 'accuracy': predicted_accuracy}