Update main.py
Browse files
main.py
CHANGED
@@ -19,12 +19,9 @@ feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
|
|
19 |
# Load the trained model weights
|
20 |
num_classes = 7
|
21 |
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
|
22 |
-
model.load_state_dict(torch.load("
|
23 |
model.eval()
|
24 |
|
25 |
-
# Define device
|
26 |
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
27 |
-
model.to(device)
|
28 |
|
29 |
# Define class labels
|
30 |
class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
|
@@ -43,7 +40,7 @@ transform = transforms.Compose([
|
|
43 |
async def predict(file: UploadFile = File(...)):
|
44 |
contents = await file.read()
|
45 |
image = Image.open(io.BytesIO(contents))
|
46 |
-
image = transform(image).unsqueeze(0)
|
47 |
|
48 |
with torch.no_grad():
|
49 |
outputs = model(image)
|
|
|
19 |
# Load the trained model weights
|
20 |
num_classes = 7
|
21 |
model.classifier = nn.Linear(model.config.hidden_size, num_classes)
|
22 |
+
model.load_state_dict(torch.load("skin_cancer_model.pth", map_location=torch.device('cpu')))
|
23 |
model.eval()
|
24 |
|
|
|
|
|
|
|
25 |
|
26 |
# Define class labels
|
27 |
class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
|
|
|
40 |
async def predict(file: UploadFile = File(...)):
|
41 |
contents = await file.read()
|
42 |
image = Image.open(io.BytesIO(contents))
|
43 |
+
image = transform(image).unsqueeze(0) # Add batch dimension and move to device
|
44 |
|
45 |
with torch.no_grad():
|
46 |
outputs = model(image)
|