Anwarkh1 commited on
Commit
bf03d3b
·
verified ·
1 Parent(s): 88b4498

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -2
app.py CHANGED
@@ -1,3 +1,50 @@
1
- import gradio as gr
 
 
 
 
 
 
2
 
3
- gr.load("models/Anwarkh1/Skin_Cancer-Image_Classification").launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from fastapi import FastAPI, UploadFile, File
2
+ from transformers import ViTForImageClassification, ViTFeatureExtractor
3
+ import torch
4
+ import torch.nn as nn
5
+ import torchvision.transforms as transforms
6
+ from PIL import Image
7
+ import io
8
 
9
+ app = FastAPI()
10
+
11
+ # Load the ViT model and its feature extractor
12
+ model_name = "google/vit-base-patch16-224-in21k"
13
+ model = ViTForImageClassification.from_pretrained(model_name)
14
+ feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
15
+
16
+ # Load the trained model weights
17
+ num_classes = 7
18
+ model.classifier = nn.Linear(model.config.hidden_size, num_classes)
19
+ # Load the trained weights
20
+ model.load_state_dict(torch.load("models/Anwarkh1/Skin_Cancer-Image_Classification", map_location=torch.device('cpu')))
21
+ model.eval()
22
+
23
+ # Define class labels
24
+ class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
25
+
26
+ # Define image transformations
27
+ transform = transforms.Compose([
28
+ transforms.Resize((224, 224)),
29
+ transforms.ToTensor(),
30
+ ])
31
+
32
+ # Define API endpoint for model inference
33
+ @app.post('/predict')
34
+ async def predict(file: UploadFile = File(...)):
35
+ contents = await file.read()
36
+ image = Image.open(io.BytesIO(contents))
37
+ image = transform(image).unsqueeze(0) # Add batch dimension
38
+
39
+ with torch.no_grad():
40
+ outputs = model(image)
41
+
42
+ # Calculate softmax probabilities
43
+ probabilities = torch.softmax(outputs.logits, dim=1)
44
+
45
+ # Get predicted class index and its probability
46
+ predicted_idx = torch.argmax(probabilities).item()
47
+ predicted_label = class_labels[predicted_idx]
48
+ predicted_accuracy = probabilities[0][predicted_idx].item()
49
+
50
+ return {'predicted_class': predicted_label, 'accuracy': predicted_accuracy}