Update main.py
Browse files
main.py
CHANGED
@@ -2,7 +2,7 @@ import os
|
|
2 |
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache" # Set cache directory to a writable location
|
3 |
|
4 |
from fastapi import FastAPI, UploadFile, File
|
5 |
-
from transformers import ViTForImageClassification, ViTFeatureExtractor
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
import torchvision.transforms as transforms
|
@@ -12,9 +12,16 @@ import io
|
|
12 |
app = FastAPI()
|
13 |
|
14 |
# Load the ViT model and its feature extractor
|
15 |
-
model_name = "
|
16 |
model = ViTForImageClassification.from_pretrained(model_name)
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
# Define class labels
|
20 |
class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
|
|
|
2 |
os.environ["TRANSFORMERS_CACHE"] = "/tmp/huggingface_cache" # Set cache directory to a writable location
|
3 |
|
4 |
from fastapi import FastAPI, UploadFile, File
|
5 |
+
from transformers import ViTForImageClassification, ViTFeatureExtractor
|
6 |
import torch
|
7 |
import torch.nn as nn
|
8 |
import torchvision.transforms as transforms
|
|
|
12 |
app = FastAPI()
|
13 |
|
14 |
# Load the ViT model and its feature extractor
|
15 |
+
model_name = "google/vit-base-patch16-224-in21k"
|
16 |
model = ViTForImageClassification.from_pretrained(model_name)
|
17 |
+
feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
18 |
+
|
19 |
+
# Load the trained model weights
|
20 |
+
num_classes = 7
|
21 |
+
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
|
22 |
+
# Load the trained weights
|
23 |
+
model.load_state_dict(torch.load("skin_cancer_model.pth", map_location=torch.device('cpu')))
|
24 |
+
model.eval()
|
25 |
|
26 |
# Define class labels
|
27 |
class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
|