Anwarkh1 commited on
Commit
49db250
·
verified ·
1 Parent(s): 1e7eb1a

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +18 -8
main.py CHANGED
@@ -19,35 +19,45 @@ feature_extractor = ViTFeatureExtractor.from_pretrained(model_name)
19
  # Load the trained model weights
20
  num_classes = 7
21
  model.classifier = nn.Linear(model.config.hidden_size, num_classes)
22
- # Load the trained weights
23
- model.load_state_dict(torch.load("skin_cancer_model.pth", map_location=torch.device('cpu')))
24
  model.eval()
25
 
 
 
 
 
26
  # Define class labels
27
  class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
28
 
 
 
 
29
  # Define image transformations
30
  transform = transforms.Compose([
31
  transforms.Resize((224, 224)),
32
  transforms.ToTensor(),
33
  ])
34
 
35
- # Define API endpoint for model inference
36
  @app.post('/predict')
37
  async def predict(file: UploadFile = File(...)):
38
  contents = await file.read()
39
  image = Image.open(io.BytesIO(contents))
40
- image = transform(image).unsqueeze(0) # Add batch dimension
41
 
42
  with torch.no_grad():
43
  outputs = model(image)
44
 
45
  # Calculate softmax probabilities
46
- probabilities = torch.softmax(outputs.logits, dim=1)
47
 
48
  # Get predicted class index and its probability
49
- predicted_idx = torch.argmax(probabilities).item()
50
  predicted_label = class_labels[predicted_idx]
51
- predicted_accuracy = probabilities[0][predicted_idx].item()
52
 
53
- return {'predicted_class': predicted_label, 'accuracy': predicted_accuracy}
 
 
 
 
 
19
  # Load the trained model weights
20
  num_classes = 7
21
  model.classifier = nn.Linear(model.config.hidden_size, num_classes)
22
+ model.load_state_dict(torch.load("/kaggle/input/skincancer-vit/skin_cancer_model.pth", map_location=torch.device('cpu')))
 
23
  model.eval()
24
 
25
+ # Define device
26
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
27
+ model.to(device)
28
+
29
  # Define class labels
30
  class_labels = ['benign_keratosis-like_lesions', 'basal_cell_carcinoma', 'actinic_keratoses', 'vascular_lesions', 'melanocytic_Nevi', 'melanoma', 'dermatofibroma']
31
 
32
+ # Define optimal thresholds
33
+ thresholds = [0.88134295, 0.43095806, 0.39622146, 0.90647435, 0.8128958, 0.05310565, 0.15926854]
34
+
35
  # Define image transformations
36
  transform = transforms.Compose([
37
  transforms.Resize((224, 224)),
38
  transforms.ToTensor(),
39
  ])
40
 
41
+ # Define API endpoint for model inference with class-specific thresholds
42
  @app.post('/predict')
43
  async def predict(file: UploadFile = File(...)):
44
  contents = await file.read()
45
  image = Image.open(io.BytesIO(contents))
46
+ image = transform(image).unsqueeze(0).to(device) # Add batch dimension and move to device
47
 
48
  with torch.no_grad():
49
  outputs = model(image)
50
 
51
  # Calculate softmax probabilities
52
+ probabilities = torch.softmax(outputs.logits, dim=1).cpu().numpy()[0]
53
 
54
  # Get predicted class index and its probability
55
+ predicted_idx = torch.argmax(torch.tensor(probabilities)).item()
56
  predicted_label = class_labels[predicted_idx]
57
+ predicted_probability = probabilities[predicted_idx]
58
 
59
+ # Check if the predicted probability meets the class-specific threshold
60
+ if predicted_probability < thresholds[predicted_idx]:
61
+ return {'predicted_class': 'uncertain', 'accuracy': predicted_probability}
62
+ else:
63
+ return {'predicted_class': predicted_label, 'accuracy': predicted_probability}