mrs83 commited on
Commit
1392818
·
verified ·
1 Parent(s): d7d9be4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +2 -151
app.py CHANGED
@@ -1,153 +1,4 @@
1
- import gradio as gr
2
- import torch
3
- from PIL import Image
4
- from datasets import load_dataset
5
- import random
6
-
7
- from skincancer_vit.model import SkinCancerViTModel
8
-
9
- HF_MODEL_REPO = "ethicalabs/SkinCancerViT"
10
- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
11
-
12
-
13
- print(f"Loading SkinCancerViT model from {HF_MODEL_REPO} to {DEVICE}...")
14
-
15
- model = SkinCancerViTModel.from_pretrained(HF_MODEL_REPO)
16
- model.to(DEVICE)
17
- model.eval() # Set to evaluation mode
18
- print("Model loaded successfully.")
19
-
20
- print("Loading 'marmal88/skin_cancer' dataset for random samples...")
21
- dataset = load_dataset("marmal88/skin_cancer", split="test")
22
- print("Dataset loaded successfully.")
23
-
24
-
25
- def predict_uploaded_image(image: Image.Image, age: int, localization: str) -> str:
26
- """
27
- Handles prediction for an uploaded image with user-provided tabular data.
28
- """
29
- if model is None:
30
- return "Error: Model not loaded. Please check the console for details."
31
- if image is None:
32
- return "Please upload an image."
33
- if age is None:
34
- return "Please enter an age."
35
- if not localization:
36
- return "Please select a localization."
37
-
38
- try:
39
- # Call the model's full_predict method
40
- predicted_dx, confidence = model.full_predict(
41
- raw_image=image, raw_age=age, raw_localization=localization, device=DEVICE
42
- )
43
- return f"Predicted Diagnosis: **{predicted_dx}** (Confidence: {confidence:.4f})"
44
- except Exception as e:
45
- return f"Prediction Error: {e}"
46
-
47
-
48
- # --- Prediction Function for Random Sample ---
49
- def predict_random_sample() -> str:
50
- """
51
- Fetches a random sample from the dataset and performs prediction.
52
- """
53
- if model is None:
54
- return "Error: Model not loaded. Please check the console for details."
55
- if dataset is None:
56
- return "Error: Dataset not loaded. Cannot select random sample."
57
-
58
- try:
59
- # Select a random sample from the dataset
60
- random_idx = random.randint(0, len(dataset) - 1)
61
- sample = dataset[random_idx]
62
-
63
- sample_image = sample["image"]
64
- sample_age = sample["age"]
65
- sample_localization = sample["localization"]
66
- sample_true_dx = sample["dx"]
67
-
68
- # Call the model's full_predict method
69
- predicted_dx, confidence = model.full_predict(
70
- raw_image=sample_image,
71
- raw_age=sample_age,
72
- raw_localization=sample_localization,
73
- device=DEVICE,
74
- )
75
-
76
- # Return a formatted string with all information
77
- result_str = (
78
- f"**Random Sample Details:**\n"
79
- f"- Age: {sample_age}\n"
80
- f"- Localization: {sample_localization}\n"
81
- f"- True Diagnosis: **{sample_true_dx}**\n\n"
82
- f"**Model Prediction:**\n"
83
- f"- Predicted Diagnosis: **{predicted_dx}**\n"
84
- f"- Confidence: {confidence:.4f}\n"
85
- f"- Correct Prediction: {'✅ Yes' if predicted_dx == sample_true_dx else '❌ No'}"
86
- )
87
- return sample_image, result_str
88
- except Exception as e:
89
- return None, f"Prediction Error on Random Sample: {e}"
90
-
91
-
92
- # --- Gradio Interface ---
93
- with gr.Blocks(title="Skin Cancer ViT Predictor") as demo:
94
- gr.Markdown(
95
- """
96
- # Skin Cancer ViT Predictor
97
- This application demonstrates the `SkinCancerViT` multimodal model for skin cancer diagnosis.
98
- It can take an uploaded image with patient metadata or predict on a random sample from the dataset.
99
- **Disclaimer:** This tool is for demonstration and research purposes only and should not be used for medical diagnosis.
100
- """
101
- )
102
-
103
- with gr.Tab("Predict on Random Sample"):
104
- gr.Markdown("## Get a Prediction from a Random Sample in the Test Set")
105
- random_sample_button = gr.Button("Get Random Sample Prediction")
106
-
107
- # Modified output components for random sample tab
108
- with gr.Row():
109
- output_random_image = gr.Image(
110
- type="pil", label="Random Sample Image", height=250, width=250
111
- )
112
- output_random_details = gr.Markdown(
113
- "Random sample details and prediction will appear here."
114
- )
115
-
116
- random_sample_button.click(
117
- fn=predict_random_sample,
118
- inputs=[],
119
- outputs=[
120
- output_random_image,
121
- output_random_details,
122
- ], # Map to both image and markdown outputs
123
- )
124
-
125
- with gr.Tab("Upload Image & Predict"):
126
- gr.Markdown("## Upload Your Image and Provide Patient Data")
127
- with gr.Row():
128
- image_input = gr.Image(
129
- type="pil", label="Upload Skin Lesion Image (224x224 preferred)"
130
- )
131
- with gr.Column():
132
- age_input = gr.Number(
133
- label="Patient Age", minimum=0, maximum=120, step=1
134
- )
135
- # Ensure these localizations match your training data categories
136
- localization_input = gr.Dropdown(
137
- model.config.localization_to_id.keys(),
138
- label="Lesion Localization",
139
- value="unknown", # Default value
140
- )
141
- predict_button = gr.Button("Get Prediction")
142
-
143
- output_upload = gr.Markdown("Prediction will appear here.")
144
-
145
- predict_button.click(
146
- fn=predict_uploaded_image,
147
- inputs=[image_input, age_input, localization_input],
148
- outputs=output_upload,
149
- )
150
 
151
  if __name__ == "__main__":
152
- demo.launch(share=False)
153
-
 
1
+ from skincancer_vit.gradio_app import *
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2
 
3
  if __name__ == "__main__":
4
+ demo.launch(share=False)