from fastapi import FastAPI, UploadFile, File
from transformers import ViTForImageClassification, ViTFeatureExtractor
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from PIL import Image
import io
app = FastAPI()
# Load the ViT model and its feature extractor
model_name = "google/vit-base-patch16-224-in21k"
model = ViTForImageClassification.from_pretrained(model_name)
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
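# Note: the feature extractor is loaded here but never used in the request path;
# preprocessing is handled by the torchvision transforms defined below.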
# Replace the classifier head to match the 7 skin-lesion classes
num_classes = 7
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
# Load the trained weights
model.load_state_dict(torch.load("models/Anwarkh1/Skin_Cancer-Image_Classification", map_location=torch.device('cpu')))
model.eval()
# Define class labels
class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
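# The label order must match the class-index order used during fine-tuning;
# if it differs, predictions will be mapped to the wrong label names.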
# Define image transformations
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
])
# Define API endpoint for model inference
@app.post('/predict')
async def predict(file: UploadFile = File(...)):
    contents = await file.read()
    # Convert to RGB so grayscale or RGBA uploads don't break the 3-channel model input
    image = Image.open(io.BytesIO(contents)).convert("RGB")
    image = transform(image).unsqueeze(0)  # Add batch dimension
    with torch.no_grad():
        outputs = model(image)
    # Calculate softmax probabilities
    probabilities = torch.softmax(outputs.logits, dim=1)
    # Get predicted class index and its probability
    predicted_idx = torch.argmax(probabilities).item()
    predicted_label = class_labels[predicted_idx]
    predicted_accuracy = probabilities[0][predicted_idx].item()
    return {'predicted_class': predicted_label, 'accuracy': predicted_accuracy}
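
# Example usage (a sketch, assuming this file is saved as main.py and uvicorn is installed):
#   uvicorn main:app --host 0.0.0.0 --port 8000
#   curl -X POST -F "file=@lesion.jpg" http://localhost:8000/predict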